diff --git a/.gitignore b/.gitignore index 828bbe9bd3363853ae3f58f54a8d5f60cefad837..b5306b8b79c37166e5496cf17a3e39b86b9a6314 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ __pycache__ cmake_build/ .idea/** /build/ +[Bb]uild/ /tensorflow/core/util/version_info.cc /tensorflow/python/framework/fast_tensor_util.cpp Pods diff --git a/README.md b/README.md index 63853137cfd30b396f8c7d204811f3e4a1794c07..05fcb23f7edd657f2ea495d848fadc226e56b524 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ data flow graphs. The graph nodes represent mathematical operations, while the graph edges represent the multidimensional data arrays (tensors) that flow between them. This flexible architecture enables you to deploy computation to one or more CPUs or GPUs in a desktop, server, or mobile device without rewriting -code. TensorFlow also includes [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard), a data visualization toolkit. +code. TensorFlow also includes [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard), a data visualization toolkit. TensorFlow was originally developed by researchers and engineers working on the Google Brain team within Google's Machine Intelligence Research @@ -96,6 +96,8 @@ The TensorFlow project strives to abide by generally accepted best practices in | --- | --- | --- | | **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA | | **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA | +| **IBM ppc64le GPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/) | TBA | +| **Linux CPU with IntelĀ® MKL-DNNĀ®** | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | TBA | ## For more information diff --git a/RELEASE.md b/RELEASE.md index e09e9c6190f57adec67c2ae1d85848dabfd9c2a7..7e6325af14d007a39d272817e2c4d476da9ce119 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,18 +1,38 @@ # Release 1.9.0 ## Major Features And Improvements -* Update tf.keras to the Keras 2.1.6 API. -* `tfe.Network` is deprecated. Please inherit from `tf.keras.Model`. -* Adding support of core feature columns and losses to gradient boosted trees estimators. -* The distributions.Bijector API supports broadcasting for Bijectors with new API changes. See [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/distributions/bijectors/Bijector) for more details. -* Layered variable names have changed in the following conditions: - * Using `tf.keras.layers` with custom variable scopes. - * Using `tf.layers` in a subclassed `tf.keras.Model` class. See [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) for more details - +* Updated docs for `tf.keras`: New Keras-based [get started](http://tensorflow.org/versions/r1.9/get_started), + and [programmers guide page](http://tensorflow.org/versions/r1.9/programmers_guide/keras). +* Update `tf.keras` to the Keras 2.1.6 API. +* Added [`tf.keras.layers.CuDNNGRU`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNGRU) and [`tf.keras.layers.CuDNNLSTM`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNLSTM) layers. [Try it](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb?linkId=53292082). +* Adding support of core [feature columns](https://www.tensorflow.org/get_started/feature_columns) and [losses](https://www.tensorflow.org/api_docs/python/tf/losses) to [gradient boosted trees estimators](https://github.com/tensorflow/models/tree/master/official/boosted_trees). +* The [python interface](https://tensorflow-dot-devsite.googleplex.com/versions/r1.9/api_docs/python/tf/contrib/lite) + for the [TFLite Optimizing Converter](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/README.md) + has been expanded, and the command line interface (AKA: `toco`, `tflite_convert`) is once again + included in the standard `pip` installation. +* Improved data-loading and text processing with: + * [`tf.decode_compressed`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/decode_compressed) + * [`tf.string_strip`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/string_strip) + * [`tf.strings.regex_full_match`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/strings/regex_full_match) +* Added experimental support for new pre-made Estimators: + * [`tf.contrib.estimator.BaselineEstimator`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/estimator/BaselineEstimator) + * [`tf.contrib.estimator.RNNClassifier`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/estimator/RNNEstimator) + * [`tf.contrib.estimator.RNNEstimator`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/estimator/RNNClassifier) +* The [distributions.Bijector](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/distributions/bijectors/Bijector) + API supports broadcasting for Bijectors with new API changes. + ## Breaking Chances - * If you're opening empty variable scopes; replace `variable_scope`('', ...) by `variable_scope`(`tf.get_variable_scope()`, ...). + * If you're opening empty variable scopes; replace `variable_scope('', ...)` by + `variable_scope(tf.get_variable_scope(), ...)`. + * Headers used for building custom ops have been moved from site-packages/external into site-packages/tensorflow/include/external. ## Bug Fixes and Other Changes + +* `tfe.Network` is deprecated. Please inherit from `tf.keras.Model`. +* Layered variable names have changed in the following conditions: + * Using `tf.keras.layers` with custom variable scopes. + * Using `tf.layers` in a subclassed `tf.keras.Model` class. See + [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) for more details * `tf.data`: * The `DatasetBase::DebugString()` method is now `const`. * Added the `tf.contrib.data.sample_from_datasets()` API for randomly sampling from multiple datasets. @@ -465,7 +485,7 @@ answered questions, and were part of inspiring discussions. ## Major Features And Improvements * `tf.keras` is now part of the core TensorFlow API. -* [`tf.data`](http://tensorflow.org/programmers_guide/datasets) is now part of +* [`tf.data`](http://tensorflow.org/guide/datasets) is now part of the core TensorFlow API. * The API is now subject to backwards compatibility guarantees. * For a guide to migrating from the `tf.contrib.data` API, see the @@ -485,7 +505,7 @@ answered questions, and were part of inspiring discussions. * TensorFlow Debugger (tfdbg): * Add `eval` command to allow evaluation of arbitrary Python/numpy expressions in tfdbg command-line interface. See - [Debugging TensorFlow Programs](https://www.tensorflow.org/programmers_guide/debugger) + [Debugging TensorFlow Programs](https://www.tensorflow.org/guide/debugger) for more details. * Usability improvement: The frequently used tensor filter `has_inf_or_nan` is now added to `Session` wrappers and hooks by default. So there is no need @@ -772,7 +792,7 @@ answered questions, and were part of inspiring discussions. * Support client-provided ClusterSpec's and propagate them to all workers to enable the creation of dynamic TensorFlow clusters. * TensorFlow C library now available for Windows. * We released a new open-source version of TensorBoard. -* [`SavedModel CLI`](https://www.tensorflow.org/versions/master/programmers_guide/saved_model_cli) tool available to inspect and execute MetaGraph in SavedModel +* [`SavedModel CLI`](https://www.tensorflow.org/versions/master/guide/saved_model_cli) tool available to inspect and execute MetaGraph in SavedModel * Android releases of TensorFlow are now pushed to jcenter for easier integration into apps. See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/android/README.md diff --git a/configure.py b/configure.py index ada342a50ab5104509156d3e44e6435a308255a3..5243e09b244fea89c0d36cea73b93309aef7e595 100644 --- a/configure.py +++ b/configure.py @@ -943,6 +943,35 @@ def set_tf_cudnn_version(environ_cp): write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version) +def is_cuda_compatible(lib, cuda_ver, cudnn_ver): + """Check compatibility between given library and cudnn/cudart libraries.""" + ldd_bin = which('ldd') or '/usr/bin/ldd' + ldd_out = run_shell([ldd_bin, lib], True) + ldd_out = ldd_out.split(os.linesep) + cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$') + cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$') + cudnn = None + cudart = None + cudnn_ok = True # assume no cudnn dependency by default + cuda_ok = True # assume no cuda dependency by default + for line in ldd_out: + if 'libcudnn.so' in line: + cudnn = cudnn_pattern.search(line) + cudnn_ok = False + elif 'libcudart.so' in line: + cudart = cuda_pattern.search(line) + cuda_ok = False + if cudnn and len(cudnn.group(1)): + cudnn = convert_version_to_int(cudnn.group(1)) + if cudart and len(cudart.group(1)): + cudart = convert_version_to_int(cudart.group(1)) + if cudnn is not None: + cudnn_ok = (cudnn == cudnn_ver) + if cudart is not None: + cuda_ok = (cudart == cuda_ver) + return cudnn_ok and cuda_ok + + def set_tf_tensorrt_install_path(environ_cp): """Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION. @@ -959,8 +988,8 @@ def set_tf_tensorrt_install_path(environ_cp): raise ValueError('Currently TensorRT is only supported on Linux platform.') # Ask user whether to add TensorRT support. - if str(int(get_var( - environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False))) != '1': + if str(int(get_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', + False))) != '1': return for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): @@ -973,47 +1002,29 @@ def set_tf_tensorrt_install_path(environ_cp): # Result returned from "read" will be used unexpanded. That make "~" # unusable. Going through one more level of expansion to handle that. - trt_install_path = os.path.realpath( - os.path.expanduser(trt_install_path)) + trt_install_path = os.path.realpath(os.path.expanduser(trt_install_path)) def find_libs(search_path): """Search for libnvinfer.so in "search_path".""" fl = set() if os.path.exists(search_path) and os.path.isdir(search_path): - fl.update([os.path.realpath(os.path.join(search_path, x)) - for x in os.listdir(search_path) if 'libnvinfer.so' in x]) + fl.update([ + os.path.realpath(os.path.join(search_path, x)) + for x in os.listdir(search_path) + if 'libnvinfer.so' in x + ]) return fl possible_files = find_libs(trt_install_path) possible_files.update(find_libs(os.path.join(trt_install_path, 'lib'))) possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64'))) - - def is_compatible(tensorrt_lib, cuda_ver, cudnn_ver): - """Check the compatibility between tensorrt and cudnn/cudart libraries.""" - ldd_bin = which('ldd') or '/usr/bin/ldd' - ldd_out = run_shell([ldd_bin, tensorrt_lib]).split(os.linesep) - cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$') - cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$') - cudnn = None - cudart = None - for line in ldd_out: - if 'libcudnn.so' in line: - cudnn = cudnn_pattern.search(line) - elif 'libcudart.so' in line: - cudart = cuda_pattern.search(line) - if cudnn and len(cudnn.group(1)): - cudnn = convert_version_to_int(cudnn.group(1)) - if cudart and len(cudart.group(1)): - cudart = convert_version_to_int(cudart.group(1)) - return (cudnn == cudnn_ver) and (cudart == cuda_ver) - cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION']) cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION']) nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$') highest_ver = [0, None, None] for lib_file in possible_files: - if is_compatible(lib_file, cuda_ver, cudnn_ver): + if is_cuda_compatible(lib_file, cuda_ver, cudnn_ver): matches = nvinfer_pattern.search(lib_file) if len(matches.groups()) == 0: continue @@ -1029,12 +1040,13 @@ def set_tf_tensorrt_install_path(environ_cp): # Try another alternative from ldconfig. ldconfig_bin = which('ldconfig') or '/sbin/ldconfig' ldconfig_output = run_shell([ldconfig_bin, '-p']) - search_result = re.search( - '.*libnvinfer.so\\.?([0-9.]*).* => (.*)', ldconfig_output) + search_result = re.search('.*libnvinfer.so\\.?([0-9.]*).* => (.*)', + ldconfig_output) if search_result: libnvinfer_path_from_ldconfig = search_result.group(2) if os.path.exists(libnvinfer_path_from_ldconfig): - if is_compatible(libnvinfer_path_from_ldconfig, cuda_ver, cudnn_ver): + if is_cuda_compatible(libnvinfer_path_from_ldconfig, cuda_ver, + cudnn_ver): trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig) tf_tensorrt_version = search_result.group(1) break @@ -1122,7 +1134,9 @@ def set_tf_nccl_install_path(environ_cp): nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path) nccl_hdr_path = os.path.join(nccl_install_path, 'include/nccl.h') - if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path): + nccl_license_path = os.path.join(nccl_install_path, 'NCCL-SLA.txt') + if os.path.exists(nccl_lib_path) and os.path.exists( + nccl_hdr_path) and os.path.exists(nccl_license_path): # Set NCCL_INSTALL_PATH environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 4e212e96dcfe4ad2b2055ea9abb150e9fd5c1f28..f362900387e506e935d4ede9aa781a83948fe0da 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -257,6 +257,13 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "with_cuda_support_windows_override", + define_values = {"using_cuda_nvcc": "true"}, + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + config_setting( name = "with_gcp_support_android_override", define_values = {"with_gcp_support": "true"}, @@ -404,6 +411,7 @@ config_setting( package_group( name = "internal", packages = [ + "-//third_party/tensorflow/python/estimator", "//learning/meta_rank/...", "//tensorflow/...", "//tensorflow_fold/llgtm/...", @@ -578,11 +586,20 @@ gen_api_init_files( py_library( name = "tensorflow_py", - srcs = [ - ":tensorflow_python_api_gen", - "//tensorflow/python/estimator/api:estimator_python_api_gen", + srcs = ["//tensorflow/python/estimator/api:estimator_python_api_gen"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":tensorflow_py_no_contrib", + "//tensorflow/contrib:contrib_py", + "//tensorflow/python/estimator:estimator_py", ], +) + +py_library( + name = "tensorflow_py_no_contrib", + srcs = [":tensorflow_python_api_gen"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = ["//tensorflow/python"], + deps = ["//tensorflow/python:no_contrib"], ) diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 9662d7b478ba61c69edc20b0d47293f9939e7881..779f65d5b17c350833f67f07985b00e8eb561e72 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -20,7 +20,6 @@ from __future__ import print_function # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import -# API IMPORTS PLACEHOLDER try: import os # pylint: disable=g-import-not-at-top @@ -37,6 +36,8 @@ try: except (ImportError, AttributeError): print('tf.estimator package not installed.') +# API IMPORTS PLACEHOLDER + from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') del LazyLoader diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 12f0d8bff4720d98b7f45b113dc62c881e32a399..5c218d3f25e01f0e78916d4a5a8b1d2751f9dc25 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -45,6 +45,7 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/validate.h" #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -390,64 +391,6 @@ void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers, status->status = Reset(opt->options, container_names); } -// This traverses the specified nodes in topological order to verify there are -// no cycles. Starting with inputless nodes, it visits nodes whose inputs have -// all been visited, and counts the total number of visited nodes. If there is a -// cycle, nodes in the cycle will never be visited, and the visited count will -// be less than the total node count. -Status ValidateNoCycles(const Graph& g) { - // TODO(nolivia): check this on a subset of the graph instead of all of it. - // A node is ready when all of its inputs have been visited. - std::vector ready; - std::vector pending_count(g.num_node_ids(), 0); - - for (int i = 0; i < g.num_node_ids(); ++i) { - const Node* n = g.FindNodeId(i); - if (n == nullptr) continue; - pending_count[i] = n->in_edges().size(); - if (n->IsMerge()) { - // While-loop cycles are legal cycles so we manually adjust the - // pending_count to make sure that the loop is visited. - for (const Edge* e : n->in_edges()) { - if (!e->IsControlEdge() && e->src()->IsNextIteration()) { - pending_count[i]--; - } - } - } - if (pending_count[i] == 0) { - ready.push_back(n); - } - } - - int processed = 0; - while (!ready.empty()) { - const Node* node = ready.back(); - ready.pop_back(); - ++processed; - - for (const Edge* out : node->out_edges()) { - const int output_id = out->dst()->id(); - pending_count[output_id]--; - if (pending_count[output_id] == 0) { - ready.push_back(out->dst()); - } - } - } - - if (processed < g.num_nodes()) { - std::vector nodes_in_cycle; - for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3; - ++i) { - if (pending_count[i] != 0) { - nodes_in_cycle.push_back(g.FindNodeId(i)->name()); - } - } - return errors::InvalidArgument( - "Graph is invalid, contains a cycle with ", g.num_nodes() - processed, - " nodes, including: ", str_util::Join(nodes_in_cycle, ", ")); - } - return Status::OK(); -} } // namespace } // namespace tensorflow @@ -746,7 +689,9 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { const auto num_nodes = graph.num_node_ids(); if (session->last_num_graph_nodes < num_nodes) { - status->status = tensorflow::ValidateNoCycles(session->graph->graph); + // TODO(nolivia): check this on a subset of the graph instead of all of + // it. + status->status = graph::ValidateGraphHasNoCycle(session->graph->graph); if (!status->status.ok()) { session->graph->mu.unlock(); return false; @@ -2123,7 +2068,8 @@ TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults( TF_Graph* graph, const TF_Buffer* graph_def, const TF_ImportGraphDefOptions* options, TF_Status* status) { GraphDef def; - if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, graph_def->length)) { + if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, + graph_def->length)) { status->status = InvalidArgument("Invalid GraphDef"); return nullptr; } @@ -2153,7 +2099,8 @@ void TF_GraphImportGraphDefWithReturnOutputs( return; } GraphDef def; - if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, graph_def->length)) { + if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, + graph_def->length)) { status->status = InvalidArgument("Invalid GraphDef"); return; } @@ -2469,7 +2416,18 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) { Node* n = g->graph.FindNodeId(i); if (n == nullptr) continue; - g->name_map[n->name()] = n; + // We have a convoluted scheme here: Using the C++ graph construction API + // to add potentially many nodes to the graph without running the checks + // (such as uniqueness of the names of nodes) we run with other functions + // that add a node to the graph (like TF_FinishOperation). + if (!g->name_map.insert(std::make_pair(n->name(), n)).second) { + status->status = tensorflow::errors::Internal( + "BUG: The API allowed construction of a graph with duplicate node " + "names (", + n->name(), + "). This is a bug. Please file an issue at " + "https://github.com/tensorflow/tensorflow/issues."); + } } } diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index c8594347451dffd465d7fa926cc53818dc9e38d4..1eb75ef11ff337dfcb2e016e09804fc04662fcda 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -894,7 +894,8 @@ TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( TF_ImportGraphDefOptions* opts); // Set the prefix to be prepended to the names of nodes in `graph_def` that will -// be imported into `graph`. +// be imported into `graph`. `prefix` is copied and has no lifetime +// requirements. TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetPrefix( TF_ImportGraphDefOptions* opts, const char* prefix); @@ -915,6 +916,7 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyPrefix( // Set any imported nodes with input `src_name:src_index` to have that input // replaced with `dst`. `src_name` refers to a node in the graph to be imported, // `dst` references a node already existing in the graph being imported into. +// `src_name` is copied and has no lifetime requirements. TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddInputMapping( TF_ImportGraphDefOptions* opts, const char* src_name, int src_index, TF_Output dst); @@ -922,7 +924,7 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddInputMapping( // Set any imported nodes with control input `src_name` to have that input // replaced with `dst`. `src_name` refers to a node in the graph to be imported, // `dst` references an operation already existing in the graph being imported -// into. +// into. `src_name` is copied and has no lifetime requirements. TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsRemapControlDependency( TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst); @@ -934,6 +936,7 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddControlDependency( // Add an output in `graph_def` to be returned via the `return_outputs` output // parameter of TF_GraphImportGraphDef(). If the output is remapped via an input // mapping, the corresponding existing tensor in `graph` will be returned. +// `oper_name` is copied and has no lifetime requirements. TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOutput( TF_ImportGraphDefOptions* opts, const char* oper_name, int index); @@ -943,7 +946,8 @@ TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOutputs( const TF_ImportGraphDefOptions* opts); // Add an operation in `graph_def` to be returned via the `return_opers` output -// parameter of TF_GraphImportGraphDef(). +// parameter of TF_GraphImportGraphDef(). `oper_name` is copied and has no +// lifetime requirements. TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOperation( TF_ImportGraphDefOptions* opts, const char* oper_name); diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 577f10c5e69ea9ecbe8ce821c6bd5167e98bef25..bc04b53fbb7fa9ba46228ae5a4ec8ee96df5f3dc 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -1160,7 +1160,7 @@ TEST(CAPI, GetOpDef) { } void StringVectorToArrays(const std::vector& v, - std::unique_ptr* ptrs, + std::unique_ptr* ptrs, std::unique_ptr* lens) { ptrs->reset(new const void*[v.size()]); lens->reset(new size_t[v.size()]); @@ -1196,7 +1196,7 @@ class CApiColocationTest : public ::testing::Test { void SetViaStringList(TF_OperationDescription* desc, const std::vector& list) { - std::unique_ptr list_ptrs; + std::unique_ptr list_ptrs; std::unique_ptr list_lens; StringVectorToArrays(list, &list_ptrs, &list_lens); TF_SetAttrStringList(desc, tensorflow::kColocationAttrName, list_ptrs.get(), @@ -1700,6 +1700,61 @@ TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) { TestGradientsError(false); } +void ScalarFloatFromTensor(const TF_Tensor* t, float* f) { + ASSERT_TRUE(t != nullptr); + ASSERT_EQ(TF_FLOAT, TF_TensorType(t)); + ASSERT_EQ(0, TF_NumDims(t)); + ASSERT_EQ(4, TF_TensorByteSize(t)); + float* p = static_cast(TF_TensorData(t)); + *f = *p; +} + +TEST_F(CApiGradientsTest, MultipleCallsToAddGradients) { + const float X = 3.0f, Y = 7.0f; + TF_Operation* x = Placeholder(graph_, s_, "x", TF_FLOAT); + TF_Operation* y = Placeholder(graph_, s_, "y", TF_FLOAT); + TF_Operation* xy = Mul(x, y, graph_, s_, "xy"); + TF_Output dxy_dx, dxy_dy; + + TF_Output outputs[1] = {{xy, 0}}; + TF_Output inputs[1] = {{x, 0}}; + TF_AddGradients(graph_, outputs, 1, inputs, 1, nullptr, s_, &dxy_dx); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + inputs[0] = {y, 0}; + TF_AddGradients(graph_, outputs, 1, inputs, 1, nullptr, s_, &dxy_dy); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_SessionOptions* opts = TF_NewSessionOptions(); + TF_Session* sess = TF_NewSession(graph_, opts, s_); + TF_DeleteSessionOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Output feeds[] = {{x, 0}, {y, 0}}; + TF_Tensor* feedValues[] = {FloatTensor(X), FloatTensor(Y)}; + TF_Output fetches[] = {dxy_dx, dxy_dy}; + TF_Tensor* fetchValues[] = {nullptr, nullptr}; + + TF_SessionRun(sess, nullptr /* run_options */, feeds, feedValues, 2, fetches, + fetchValues, 2, nullptr /* target_opers */, 0, + nullptr /* run_metadata */, s_); + TF_DeleteTensor(feedValues[0]); + TF_DeleteTensor(feedValues[1]); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_DeleteSession(sess, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + float dxy_dxValue = 0.0f, dxy_dyValue = 0.0f; + ScalarFloatFromTensor(fetchValues[0], &dxy_dxValue); + EXPECT_EQ(Y, dxy_dxValue); + + ScalarFloatFromTensor(fetchValues[1], &dxy_dyValue); + EXPECT_EQ(X, dxy_dyValue); + + TF_DeleteTensor(fetchValues[0]); + TF_DeleteTensor(fetchValues[1]); +} + // REGISTER_OP for CApiAttributesTest test cases. // Registers two ops, each with a single attribute called 'v'. // The attribute in one op will have a type 'type', the other @@ -1784,7 +1839,7 @@ TEST_F(CApiAttributesTest, String) { TEST_F(CApiAttributesTest, StringList) { std::vector list = {"bugs", "bunny", "duck"}; - std::unique_ptr list_ptrs; + std::unique_ptr list_ptrs; std::unique_ptr list_lens; StringVectorToArrays(list, &list_ptrs, &list_lens); int list_total_size = 0; @@ -1800,7 +1855,7 @@ TEST_F(CApiAttributesTest, StringList) { ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); EXPECT_TF_META("v", list.size(), TF_ATTR_STRING, list_total_size); - std::unique_ptr values(new void*[list.size()]); + std::unique_ptr values(new void*[list.size()]); std::unique_ptr lens(new size_t[list.size()]); std::unique_ptr storage(new char[list_total_size]); TF_OperationGetAttrStringList(oper, "v", values.get(), lens.get(), @@ -2025,7 +2080,7 @@ TEST_F(CApiAttributesTest, TensorShapeProtoList) { tensorflow::PartialTensorShape(pts2).AsProto(&proto); proto.SerializeToString(&bytes2); - std::unique_ptr list_ptrs; + std::unique_ptr list_ptrs; std::unique_ptr list_lens; const std::vector list = {bytes1, bytes2}; StringVectorToArrays(list, &list_ptrs, &list_lens); diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index f3b28c1708129d39e451d927a89c0d10e2193b63..24eb6c069b21349fce288db3e79fbf14e824ad11 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -216,6 +216,13 @@ TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph, return MinWithDevice(l, r, graph, /*op_device=*/"", s, name); } +TF_Operation* Mul(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name) { + TF_Operation* op; + BinaryOpHelper("Mul", l, r, graph, s, name, &op, "", true); + return op; +} + TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, const char* name) { TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index c16aba666ee6974fed5351c2d9ac291dcbcdecab..38313d647ca93d4779bb1325f8ed7bde4b743879 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -80,6 +80,9 @@ TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name = "min"); +TF_Operation* Mul(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name = "mul"); + // If `op_device` is non-empty, set the created op on that device. TF_Operation* MinWithDevice(TF_Operation* l, TF_Operation* r, TF_Graph* graph, const string& op_device, TF_Status* s, diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index f265da2c2c89c0e9caf14f2213c606fcb69997e0..37be52f57d865c1e59611540d5dab04b59e89444 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -54,7 +54,6 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", "//tensorflow/core/distributed_runtime/rpc:grpc_channel", - "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", @@ -93,10 +92,10 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/eager:remote_tensor_handle", "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", - "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", ], ) @@ -122,6 +121,7 @@ tf_cuda_library( tf_cuda_cc_test( name = "c_api_test", + size = "small", srcs = [ "c_api_debug_test.cc", "c_api_test.cc", @@ -139,7 +139,7 @@ tf_cuda_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", ], ) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 81221c4078bec9820ee187efdf0314da378be62b..82ca2be2cff885967dd798a1cb84b164a9df399e 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -36,9 +36,9 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/execute.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" -#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/node_def_util.h" @@ -46,10 +46,12 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -107,7 +109,8 @@ tensorflow::Status GetAllRemoteDevices( } tensorflow::Status CreateRemoteContexts( - const std::vector& remote_workers, + const std::vector& remote_workers, int64 rendezvous_id, + const tensorflow::ServerDef& server_def, tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, tensorflow::gtl::FlatMap* remote_contexts) { for (int i = 0; i < remote_workers.size(); i++) { @@ -115,12 +118,14 @@ tensorflow::Status CreateRemoteContexts( tensorflow::eager::CreateContextRequest request; tensorflow::eager::CreateContextResponse response; + request.set_rendezvous_id(rendezvous_id); tensorflow::DeviceNameUtils::ParsedName parsed_name; if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker, &parsed_name)) { return tensorflow::errors::InvalidArgument( "Unable to parse ", remote_worker, " as a device name"); } + *request.mutable_server_def() = server_def; request.mutable_server_def()->set_job_name(parsed_name.job); request.mutable_server_def()->set_task_index(parsed_name.task); request.set_async(async); @@ -147,46 +152,82 @@ tensorflow::Status CreateRemoteContexts( tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, TFE_Context** ctx) { + // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the + // server object (which currently CHECK-fails) and we miss the error, instead, + // we log the error, and then return to allow the user to see the error + // message. +#define LOG_AND_RETURN_IF_ERROR(...) \ + do { \ + const ::tensorflow::Status _status = (__VA_ARGS__); \ + if (TF_PREDICT_FALSE(!_status.ok())) { \ + LOG(ERROR) << _status.error_message(); \ + return _status; \ + } \ + } while (0); + string worker_name = tensorflow::strings::StrCat( "/job:", opts->server_def.job_name(), "/replica:0/task:", opts->server_def.task_index()); - std::unique_ptr server; - TF_RETURN_IF_ERROR( - tensorflow::eager::EagerGrpcServer::Create(opts->server_def, &server)); - TF_RETURN_IF_ERROR(server->Start()); + std::unique_ptr server; + LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(opts->server_def, &server)); + + tensorflow::GrpcServer* grpc_server = + dynamic_cast(server.get()); + if (grpc_server == nullptr) { + LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal( + "Currently, TFE_NewContext only supports tensorflow::GrpcServer.")); + } + + LOG_AND_RETURN_IF_ERROR(grpc_server->Start()); + + int64 rendezvous_id = tensorflow::random::New64(); std::vector remote_workers; - server->master_env()->worker_cache->ListWorkers(&remote_workers); + grpc_server->master_env()->worker_cache->ListWorkers(&remote_workers); remote_workers.erase( std::remove(remote_workers.begin(), remote_workers.end(), worker_name), remote_workers.end()); std::unique_ptr remote_device_mgr; - TF_RETURN_IF_ERROR(GetAllRemoteDevices( - remote_workers, server->master_env()->worker_cache, &remote_device_mgr)); + LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices( + remote_workers, grpc_server->master_env()->worker_cache, + &remote_device_mgr)); std::shared_ptr channel_cache = - server->channel_cache(); + grpc_server->channel_cache(); std::unique_ptr remote_eager_workers( tensorflow::eager::NewGrpcEagerClientCache(channel_cache)); // Initialize remote eager workers. tensorflow::gtl::FlatMap remote_contexts; - TF_RETURN_IF_ERROR(CreateRemoteContexts(remote_workers, - remote_eager_workers.get(), - opts->async, &remote_contexts)); + LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( + remote_workers, rendezvous_id, opts->server_def, + remote_eager_workers.get(), opts->async, &remote_contexts)); tensorflow::RemoteRendezvous* r = - server->worker_env()->rendezvous_mgr->Find(0); + grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id); + + auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id); + TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession( + session_name, opts->server_def, true)); + + std::shared_ptr worker_session; + TF_RETURN_IF_ERROR( + grpc_server->worker_env()->session_mgr->WorkerSessionForSession( + session_name, &worker_session)); + + // Initialize remote tensor communication based on worker session. + TF_RETURN_IF_ERROR(r->Initialize(worker_session.get())); - auto* device_mgr = server->worker_env()->device_mgr; + auto* device_mgr = grpc_server->worker_env()->device_mgr; *ctx = new TFE_Context(opts->session_options.options, opts->policy, opts->async, device_mgr, r, std::move(server), std::move(remote_eager_workers), std::move(remote_device_mgr), remote_contexts); return tensorflow::Status::OK(); +#undef LOG_AND_RETURN_IF_ERROR } } // namespace @@ -307,16 +348,16 @@ TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { } int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { - const tensorflow::Tensor* t = nullptr; - status->status = h->handle->Tensor(&t); - return t == nullptr ? 0 : t->dims(); + int result; + status->status = h->handle->NumDims(&result); + return result; } int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status) { - const tensorflow::Tensor* t = nullptr; - status->status = h->handle->Tensor(&t); - return t == nullptr ? 0 : t->dim_size(dim_index); + tensorflow::int64 result; + status->status = h->handle->Dim(dim_index, &result); + return result; } const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { @@ -421,8 +462,11 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx, return ret; } -void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) { - op->operation.MutableAttrs()->Set(attr_name, value); +void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value, + size_t length) { + op->operation.MutableAttrs()->Set( + attr_name, + tensorflow::StringPiece(static_cast(value), length)); } void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) { @@ -473,16 +517,22 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, op->operation.MutableAttrs()->Set(attr_name, attr_value); } -#define TFE_OP_SET_ATTR_LIST(fn, type) \ - void fn(TFE_Op* op, const char* attr_name, const type* values, \ - int num_values) { \ - op->operation.MutableAttrs()->Set( \ - attr_name, \ - tensorflow::gtl::ArraySlice(values, num_values)); \ +void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, + const void* const* values, const size_t* lengths, + int num_values) { + std::vector v(num_values); + for (int i = 0; i < num_values; ++i) { + v[i] = tensorflow::StringPiece(static_cast(values[i]), + lengths[i]); } -TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*) -TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float) -#undef TFE_OP_SET_ATTR_LIST + op->operation.MutableAttrs()->Set(attr_name, v); +} + +void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name, + const float* values, int num_values) { + op->operation.MutableAttrs()->Set( + attr_name, tensorflow::gtl::ArraySlice(values, num_values)); +} void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, const int64_t* values, int num_values) { @@ -655,9 +705,11 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value, const char* attr_name, TF_Status* status) { switch (default_value.value_case()) { - case tensorflow::AttrValue::kS: - TFE_OpSetAttrString(op, attr_name, default_value.s().data()); + case tensorflow::AttrValue::kS: { + const string& v = default_value.s(); + TFE_OpSetAttrString(op, attr_name, v.data(), v.size()); break; + } case tensorflow::AttrValue::kI: TFE_OpSetAttrInt(op, attr_name, static_cast(default_value.i())); break; diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 1862af3ce2f505a6e83b4805417eaf335ed07bc0..fdbd5374b2afe815c3a81b453930eb8f1fa351d3 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -278,7 +278,8 @@ TF_CAPI_EXPORT extern TF_AttrType TFE_OpNameGetAttrType( TF_CAPI_EXPORT extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, - const char* value); + const void* value, + size_t length); TF_CAPI_EXPORT extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value); TF_CAPI_EXPORT extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, @@ -305,7 +306,8 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunction(TFE_Op* op, TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, - const char** value, + const void* const* values, + const size_t* lengths, int num_values); TF_CAPI_EXPORT extern void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 04a6efc47c5177c82b7e88168b67cc584587de7c..4c5077023d5bb3b83808bf3908e7110dd026e3ad 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -39,7 +39,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/distributed_runtime/eager/eager_client.h" #include "tensorflow/core/distributed_runtime/remote_device.h" -#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" @@ -78,7 +78,7 @@ struct TFE_Context { TFE_ContextDevicePlacementPolicy default_policy, bool async, tensorflow::DeviceMgr* local_device_mgr, tensorflow::Rendezvous* rendezvous, - std::unique_ptr server, + std::unique_ptr server, std::unique_ptr remote_eager_workers, std::unique_ptr remote_device_mgr, const tensorflow::gtl::FlatMap& diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 992d1afd5fcb0641794bb2abbe5ab20a287d3b62..3504a8b5e78480732d3454097c1b2197ac2b2e17 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "tensorflow/c/eager/c_api_test_util.h" -#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -132,10 +132,10 @@ void TestRemoteExecute(bool async) { server_def.set_task_index(1); - std::unique_ptr worker_server; - ASSERT_TRUE( - tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server) - .ok()); + std::unique_ptr worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); ASSERT_TRUE(worker_server->Start().ok()); TF_Status* status = TF_NewStatus(); @@ -143,7 +143,7 @@ void TestRemoteExecute(bool async) { TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_ContextOptionsSetAsync(opts, static_cast(1)); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_EXPLICIT); TFE_Context* ctx = TFE_NewContext(opts, status); @@ -208,25 +208,31 @@ TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); } TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); } void TestRemoteExecuteSilentCopies(bool async) { - tensorflow::ServerDef server_def = GetServerDef(2); + tensorflow::ServerDef server_def = GetServerDef(3); // This server def has the task index set to 0. string serialized = server_def.SerializeAsString(); server_def.set_task_index(1); - - std::unique_ptr worker_server; - ASSERT_TRUE( - tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server) - .ok()); - ASSERT_TRUE(worker_server->Start().ok()); + std::unique_ptr worker_server1; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server1) + .ok()); + ASSERT_TRUE(worker_server1->Start().ok()); + + server_def.set_task_index(2); + std::unique_ptr worker_server2; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server2) + .ok()); + ASSERT_TRUE(worker_server2->Start().ok()); TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_ContextOptionsSetAsync(opts, static_cast(1)); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); TFE_Context* ctx = TFE_NewContext(opts, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -234,12 +240,16 @@ void TestRemoteExecuteSilentCopies(bool async) { TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); - const char remote_device_name[] = - "/job:localhost/replica:0/task:1/device:CPU:0"; + const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; + const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0"; - // Handles are on task0, but op is on remote (task1). - TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task0); - TFE_OpSetDevice(matmul, remote_device_name, status); + auto* h1_task2 = + TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Handles are on task0 (local), and task2, but op is on task1. + TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2); + TFE_OpSetDevice(matmul, task1_name, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_TensorHandle* retvals[1]; @@ -265,6 +275,7 @@ void TestRemoteExecuteSilentCopies(bool async) { TFE_DeleteTensorHandle(h0_task0); TFE_DeleteTensorHandle(h1_task0); + TFE_DeleteTensorHandle(h1_task2); TFE_DeleteTensorHandle(retvals[0]); TFE_DeleteOp(matmul); @@ -276,7 +287,8 @@ void TestRemoteExecuteSilentCopies(bool async) { TF_DeleteStatus(status); // TODO(nareshmodi): Figure out how to correctly shut the server down. - worker_server.release(); + worker_server1.release(); + worker_server2.release(); } TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); } @@ -1162,8 +1174,8 @@ TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value, if (TF_GetCode(status) != TF_OK) return nullptr; TFE_OpSetAttrType(op, "dtype", TF_FLOAT); TFE_OpSetAttrShape(op, "shape", {}, 0, status); - TFE_OpSetAttrString(op, "container", ""); - TFE_OpSetAttrString(op, "shared_name", ""); + TFE_OpSetAttrString(op, "container", "", 0); + TFE_OpSetAttrString(op, "shared_name", "", 0); if (TF_GetCode(status) != TF_OK) return nullptr; TFE_TensorHandle* var_handle = nullptr; int num_retvals = 1; diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 079e063d3e3fbdaf833e9031f5f9438853c14099..a98f0b00b2c70055f697ed4f15cb14708384b62f 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -530,7 +530,7 @@ cc_library_with_android_deps( "//tensorflow/core/api_def:base_api_def", ], deps = [ - "//tensorflow/core:framework", + "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:op_gen_lib", diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 62a889181e787f2e181135ab0563c45e1bab8812..8c886f31711eb014fb9e9d600c9c78cf22073f71 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -37,6 +37,11 @@ Scope& Scope::operator=(const Scope& other) { return *this; } +namespace { +const char kScopeSeparator[] = "/"; +const char kSuffixSeparator[] = "_"; +} // namespace + Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner, bool disable_shape_inference) : graph_(graph), @@ -308,19 +313,23 @@ string Scope::Impl::GetUniqueName(const string& prefix, return prefix; } auto entry = name_map_->find(prefix); - string unique_name = prefix; if (entry == name_map_->end()) { name_map_->insert({prefix, 0}); - } else { - unique_name = strings::StrCat(unique_name, "_", ++entry->second); + return prefix; } + string unique_name; + do { + unique_name = strings::StrCat(prefix, kSuffixSeparator, ++entry->second); + } while (name_map_->find(unique_name) != name_map_->end()); + name_map_->insert({unique_name, 0}); return unique_name; } string Scope::Impl::GetNameForOp(const string& default_name) const { const string unique_name = GetUniqueName(default_name, true /* check_single_use */); - const string sep = name_.empty() || unique_name.empty() ? "" : "/"; + const string sep = + name_.empty() || unique_name.empty() ? "" : kScopeSeparator; return strings::StrCat(name_, sep, unique_name); } @@ -345,7 +354,8 @@ Scope Scope::NewSubScope(const string& child_scope_name) const { } const string unique_name = impl()->GetUniqueName(child_scope_name, false /* check_single_use */); - const string sep = impl()->name_.empty() || unique_name.empty() ? "" : "/"; + const string sep = + impl()->name_.empty() || unique_name.empty() ? "" : kScopeSeparator; return Scope(new Impl(*this, Impl::Tags::ScopeName(), strings::StrCat(impl()->name_, sep, unique_name), false /* copy_names */)); @@ -412,7 +422,7 @@ CompositeOpScopes Scope::GetCompositeOpScopes( if (!impl()->single_use_scope()) { Scope child = NewSubScope(impl()->op_name_.empty() ? composite_op_name : impl()->op_name_); - const string child_op_sep = impl()->name_.empty() ? "" : "_"; + const string child_op_sep = impl()->name_.empty() ? "" : kSuffixSeparator; const string child_name = strings::StrCat(impl()->name_, child_op_sep, child.impl()->name_); return {child, @@ -435,7 +445,13 @@ class InternalScope { static Scope NewScope(Graph* graph, Status* status, ShapeRefiner* refiner) { Scope::Impl::NameMap* name_map = new Scope::Impl::NameMap; for (const Node* node : graph->nodes()) { - (*name_map)[node->name()] = 0; + const string& name = node->name(); + (*name_map)[name] = 0; + // Add all name prefixes ('/' separated). + size_t idx = -1; + while ((idx = name.find(kScopeSeparator, idx + 1)) != string::npos) { + (*name_map)[name.substr(0, idx)] = 0; + } } // We provide null destructors for these shared ptrs (except for name_map) // since the caller owns them and doesn't want the scope to destroy them. diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h index 8efcfed20d0b86d86d8c20a3d8630c7c6bc909c3..58adaef2e942a7fa6b0ce8d5534ac3e2fd380580 100644 --- a/tensorflow/cc/framework/scope_internal.h +++ b/tensorflow/cc/framework/scope_internal.h @@ -34,8 +34,7 @@ class Scope::Impl { // name that has not been used so far in a scope will get no suffix. Later // uses of the same name will get suffixes _1, _2, _3, etc. Multiple scopes // can share the same NameMap. For instance, a new scope created using - // WithControlDependencies() should would share the same NameMap with the - // parent. + // WithControlDependencies() would share the same NameMap with the parent. typedef std::unordered_map NameMap; Impl(const std::shared_ptr& graph, diff --git a/tensorflow/cc/framework/scope_test.cc b/tensorflow/cc/framework/scope_test.cc index 9eca9d3face34319413e1acbc2f5ac0b2ba85374..b40b345eb84237c34ea593021bea022ad28095f7 100644 --- a/tensorflow/cc/framework/scope_test.cc +++ b/tensorflow/cc/framework/scope_test.cc @@ -26,6 +26,16 @@ TEST(ScopeTest, BasicNames) { EXPECT_EQ(root.GetUniqueNameForOp("mul"), "mul"); } +TEST(ScopeTest, OpAndScopeNameCollision) { + Scope root = Scope::NewRootScope(); + EXPECT_EQ(root.GetUniqueNameForOp("foo"), "foo"); + EXPECT_EQ(root.GetUniqueNameForOp("foo"), "foo_1"); + EXPECT_EQ(root.GetUniqueNameForOp("foo_1"), "foo_1_1"); + EXPECT_EQ(root.GetUniqueNameForOp("foo_2"), "foo_2"); + EXPECT_EQ(root.GetUniqueNameForOp("foo"), "foo_3"); + EXPECT_EQ(root.GetUniqueNameForOp("foo_2"), "foo_2_1"); +} + TEST(ScopeTest, HierarchicalNames) { Scope root = Scope::NewRootScope(); Scope child = root.NewSubScope("child"); diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index ff348fadb24e29a83bd6c8853aa67931f6df4182..b353accddcb6db9a07c112de03ead2f02c4ee6a6 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -421,6 +421,58 @@ Status StridedSliceGradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("StridedSlice", StridedSliceGradHelper); +Status SliceGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // Propagate the incoming gradient along all the selected values, + // and zero everywhere else. Use the Pad operator for this. + // + // First create an Nx2 padding where N is the number of input + // dimensions. The first column is the number of prepended zeros + // for each dimension, and the second column is the number of + // appended zeros. + // + // The first column is just the begin vector. + // The second column is the shape of the input element-wise + // subtracted by begin+size + + // Running example: + // input.shape = [3, 5, 3] + // begin = [1, 2, 1], size = [1, 3, 2] + Input input = op.input(0); + Input begin = op.input(1); + // input_rank = 3 + auto input_rank = Rank(scope, input); + // slice_size = [1, 3, 2] + auto slice_size = Shape(scope, op.output(0)); + // padding_shape = [3, 1] + auto padding_shape = Stack(scope, {input_rank, 1}); + // before_padding = [[1] + // [2] + // [1]] + Input before_padding = Reshape(scope, begin, padding_shape); + // after_padding_sizes = shape(input) - slice_size - begin + // = [3, 5, 3] - [1, 3, 2] - [1, 2, 1] + // = [1, 0, 0] + auto after_padding_sizes = + Sub(scope, Sub(scope, Shape(scope, input), slice_size), begin); + // after_padding = [[1] + // [0] + // [0]] + Input after_padding = Reshape(scope, after_padding_sizes, padding_shape); + // paddings = [[1 1] + // [2 0] + // [1 0]] + auto paddings = + Concat(scope, {before_padding, after_padding}, Const(scope, 1)); + grad_outputs->push_back(Pad(scope, grad_inputs[0], paddings)); + // Nothing propagated for "begin" and "size" inputs + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("Slice", SliceGrad); + } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc index de3bd0fc9e2493f8ff76163f5be6bd4327c58c5a..d09275b6487b4212aa35a0476002f2bb587fa210 100644 --- a/tensorflow/cc/gradients/array_grad_test.cc +++ b/tensorflow/cc/gradients/array_grad_test.cc @@ -378,5 +378,12 @@ TEST_F(ArrayGradTest, StridedSliceGrad) { RunTest(x, x_shape, y, {1, 2, 2, 2}); } +TEST_F(ArrayGradTest, SliceGrad) { + TensorShape x_shape({3, 5, 3}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + auto y = Slice(scope_, x, {1, 2, 1}, {1, 3, 2}); + RunTest(x, x_shape, y, {1, 3, 2}); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 0025842aead53973befc794378a26fa8db2ae1cb..28070d60dbbe6dd8f930b8e6509cedcf09f94e11 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -287,7 +287,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config)); const int64 result_index = compile_result.aot->result_buffer_index(); const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes(); - if (result_index < 0 || result_index > temp_sizes.size()) { + if (result_index < 0 || result_index >= temp_sizes.size()) { return errors::InvalidArgument("result index: ", result_index, " is outside the range of temp sizes: [0,", temp_sizes.size(), ")"); diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 8c74014614789758192691ee065f92759a113a7a..c2245b8eae8fd27d96feaf58e26418b92e646910 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -176,9 +176,11 @@ cc_library( "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:fifo_queue", "//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:no_op", + "//tensorflow/core/kernels:queue_op", "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:sendrecv_ops", "//tensorflow/core/kernels:shape_ops", @@ -398,6 +400,32 @@ tf_cc_test( ], ) +tf_cc_test( + name = "xla_cluster_util_test", + size = "small", + srcs = [ + "xla_cluster_util_test.cc", + ], + deps = [ + ":common", + ":xla_cluster_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_cc_test( name = "xla_launch_util_test", size = "small", diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc index 731b8ebfdc6262500940274c94a03ae7c0376096..a2e6285339f9ed0bde8d72f5b4752b1ecc22f426 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op.cc @@ -66,8 +66,28 @@ class SinglePassSearch { Status CompilationRequested(const FunctionLibraryRuntime& flr, const NodeDef& node_def) { + const FunctionDef* function_def = + flr.GetFunctionLibraryDefinition()->Find(node_def.name()); + if (function_def == nullptr) { + // The node def is not calling a function. Individual ops can be + // run directly using on-demand mode, no need to create XlaLaunch + // kernel for them. + // TODO(b/110359382): Make custom kernel creation return a bool instead of + // status. + // We don't set error messages here to avoid unnecessary string copy. + // Similarly below. + return Status(error::INVALID_ARGUMENT, ""); + } + + // If kXlaCompileAttr is set on the node_def, use its value. + const auto& it = node_def.attr().find(kXlaCompileAttr); + if (it != node_def.attr().end()) { + return it->second.b() ? Status::OK() : Status(error::INVALID_ARGUMENT, ""); + } + + // kXlaCompileAttr is not set on node_def, check if it is set on + // FunctionDef. bool xla_compile = false; - // Check if op is marked _XlaCompile=true. Status status = flr.GetFunctionLibraryDefinition()->GetAttr( node_def, kXlaCompileAttr, &xla_compile); if (!status.ok() || !xla_compile) { diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 9448b8ebde09b73bf26fd8c5ad118105045ff452..e786d41887f1d539fe1ae122275d1c14c77309e8 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -1504,6 +1504,9 @@ Status Encapsulator::SplitIntoSubgraphs() { for (auto& entry : subgraphs_) { Subgraph& subgraph = entry.second; FixupSourceAndSinkEdges(subgraph.GetGraph()); + // Verify that the graph has well-formed control flow structure. + std::vector dummy; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(subgraph.GetGraph(), &dummy)); } return s; @@ -2519,10 +2522,12 @@ Status EncapsulateSubgraphsPass::Run( return Status::OK(); }; - TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( - kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph, - rewrite_subgraph, - /*reuse_existing_functions=*/false, &graph_out, library)); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + EncapsulateSubgraphsInFunctions( + kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph, + rewrite_subgraph, /*reuse_existing_functions=*/false, &graph_out, + library), + "EncapsulateSubgraphsPass failed"); if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out, diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 902fe27acdec1cb323217e6409fbd02f62177612..251a07304eaeb21f1313d7a6ef6af668f99d8551 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -166,6 +166,14 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { } XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = true; + // Optimization: don't resolve constants. If we resolve constants we never + // emit them on the device, meaning that if they are needed by a following + // computation the host has to transfer them. + compile_options.resolve_compile_time_constants = false; + // Optimization: where possible, have the computation return a naked array + // rather than a one-element tuple. + compile_options.always_return_tuple = false; + OP_REQUIRES_OK( ctx, cache->Compile(options, function_, constant_args, variables, ctx, &kernel, &executable, &compile_options)); diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index 05b7821b8865d0f210ca9af92370e177d6043e80..a5628b12a27c9ed052e22c784517a07f2c1c059a 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -139,27 +139,32 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { }; for (Edge const* edge : graph->edges()) { - if (edge->dst()->IsEnter()) { - // Lift edges to an "Enter" node to the corresponding frame node. - const string& frame_name = - control_flow_info[edge->dst()->id()].frame_name; - int dst = GetOrAddFrameNodeId(frame_name); - if (!cycles->InsertEdge(edge->src()->id(), dst)) { - return errors::Internal( - "Cycle detected when adding enter->frame edge: ", - DescribeCycle(cycles, *graph, edge->src()->id(), dst)); + if (edge->dst()->IsEnter() || edge->src()->IsExit()) { + const char* src_type = "pre-enter"; + const char* dst_type = "post-exit"; + int src = edge->src()->id(); + int dst = edge->dst()->id(); + + if (edge->dst()->IsEnter()) { + // Lift edges to an "Enter" node to the corresponding frame node. + const string& frame_name = + control_flow_info[edge->dst()->id()].frame_name; + dst = GetOrAddFrameNodeId(frame_name); + dst_type = "frame"; } - continue; - } - if (edge->src()->IsExit()) { - // Lift edges from an "Exit" node to the corresponding frame node. - const string& frame_name = - control_flow_info[edge->src()->id()].frame_name; - int src = GetOrAddFrameNodeId(frame_name); - if (!cycles->InsertEdge(src, edge->dst()->id())) { + + if (edge->src()->IsExit()) { + // Lift edges from an "Exit" node to the corresponding frame node. + const string& frame_name = + control_flow_info[edge->src()->id()].frame_name; + src = GetOrAddFrameNodeId(frame_name); + src_type = "frame"; + } + + if (!cycles->InsertEdge(src, dst)) { return errors::Internal( - "Cycle detected when adding frame->exit edge: ", - DescribeCycle(cycles, *graph, src, edge->dst()->id())); + "Cycle detected when adding ", src_type, "->", dst_type, + " edge: ", DescribeCycle(cycles, *graph, src, dst)); } // Drop the original edge. continue; diff --git a/tensorflow/compiler/jit/xla_cluster_util_test.cc b/tensorflow/compiler/jit/xla_cluster_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2cb351e1ecdb4523a8652886af156540e4736b18 --- /dev/null +++ b/tensorflow/compiler/jit/xla_cluster_util_test.cc @@ -0,0 +1,69 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_cluster_util.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(CreateCycleDetectionGraph, ConnectivityThroughEnterExitRegion) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0)); + Output enter = + ops::internal::Enter(root.WithOpName("enter"), a, "only_frame"); + Output exit = ops::internal::Exit(root.WithOpName("exit"), enter); + Output b = ops::Add(root.WithOpName("b"), a, exit); + + FixupSourceAndSinkEdges(root.graph()); + + GraphCycles cycles; + TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles)); + EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id())); +} + +TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0)); + Output enter_0 = + ops::internal::Enter(root.WithOpName("enter_0"), a, "frame_0"); + Output exit_0 = ops::internal::Exit(root.WithOpName("exit_0"), enter_0); + Output enter_1 = + ops::internal::Enter(root.WithOpName("enter_1"), a, "frame_1"); + Output exit_1 = ops::internal::Exit(root.WithOpName("exit_1"), enter_1); + Output b = ops::Add(root.WithOpName("b"), a, exit_1); + + FixupSourceAndSinkEdges(root.graph()); + + GraphCycles cycles; + TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles)); + EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id())); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 7ed609c43748062656b631243c01d790519c54fd..54a41a4daa790401c797277e7eaab531dd34ac80 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -40,7 +40,23 @@ namespace tensorflow { XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client, DeviceType device_type) : client_(client), device_type_(std::move(device_type)) {} -XlaCompilationCache::~XlaCompilationCache() = default; +XlaCompilationCache::~XlaCompilationCache() { + // Ensure any use of our programs have completed by waiting for all stream + // executors to complete. + for (auto* executor : client_->backend().stream_executors()) { + bool ok = executor->SynchronizeAllActivity(); + if (!ok) { + LOG(ERROR) << "Error synchronizing activity while waiting for all " + "programs to complete"; + } + } + // TODO(b/110813685): Think about the program ownership model. Programs are + // currently owned by the compilation cache which means we must wait for + // program completion in the destructor. There are multiple compilation caches + // around, which complicates things a little. Perhaps having programs be + // shared_ptrs (an invasive change) would make the model easier to reason + // about? +} string XlaCompilationCache::DebugString() { return "XLA JIT compilation cache"; diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index b1943d3e1a7e321b5a3796a0c6e4f2b5d9ac7018..baccea2d6a793df8c5cf8c8941706d41d2c044ca 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -61,14 +61,18 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; TF_RET_CHECK(stream); - VLOG(2) << "Executing computation."; + VLOG(2) << "Executing computation: " << name(); + for (const xla::ShapedBuffer* arg : launch_context.arguments()) { + VLOG(2) << name() << ": " << *arg; + } xla::ExecutableRunOptions run_options; run_options.set_stream(stream); run_options.set_allocator(client->backend().memory_allocator()); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); run_options.set_rng_seed(ctx->step_id()); - auto run_result = executable->Run(launch_context.arguments(), run_options); + xla::StatusOr run_result = + executable->Run(launch_context.arguments(), run_options); TF_RETURN_IF_ERROR(run_result.status()); launch_context.PopulateOutputs(ctx, result, run_result.ConsumeValueOrDie()); @@ -159,6 +163,13 @@ Status XlaCompileOnDemandOp::Compile( XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = true; + // Optimization: don't resolve constants. If we resolve constants we never + // emit them on the device, meaning that if they are needed by a following + // computation the host has to transfer them. + compile_options.resolve_compile_time_constants = false; + // Optimization: where possible, have the computation return a naked array + // rather than a one-element tuple. + compile_options.always_return_tuple = false; std::map variable_args = GetVariables(ctx); return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 71e63b110b3b132a57fc291e53a165954c72a03c..e20f5aa83766ccbdf4c19269cfbb00f9e077c2ef 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -67,36 +67,53 @@ Status XlaTransferManager::TransferLiteralToDevice( xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(), host_tensor.shape(), &xla_shape)); - xla::BorrowingLiteral literal( + // Create a reference to hold onto host_tensor until after the literal has + // been transferred. Also make sure the literal exists until the function + // asynchronously completes, as it will be wrapped in an xla::LiteralSlice. + TensorReference ref(host_tensor); + auto literal = std::make_shared( static_cast(DMAHelper::base(&host_tensor)), xla_shape); const xla::ShapedBuffer& shaped_buffer = XlaTensor::FromTensor(device_tensor)->shaped_buffer(); - VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " " + VLOG(1) << "Transfer to device as literal: " << literal->ToString() << " " << shaped_buffer.ToString(); - return transfer_manager_->TransferLiteralToDevice(stream_->parent(), literal, - shaped_buffer); + TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync( + stream_, *literal, shaped_buffer)); + // Unref the host tensor, and capture the literal shared_ptr too so it goes + // out of scope when the lambda completes. + stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); }); + return Status::OK(); } -Status XlaTransferManager::TransferLiteralFromDevice( - Tensor* host_tensor, const Tensor& device_tensor) const { +void XlaTransferManager::TransferLiteralFromDevice( + Tensor* host_tensor, const Tensor& device_tensor, + const StatusCallback& done) const { const xla::ShapedBuffer& shaped_buffer = XlaTensor::FromTensor(&device_tensor)->shaped_buffer(); - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - transfer_manager_->TransferLiteralFromDevice( - stream_->parent(), shaped_buffer)); - VLOG(1) << "Transfer from device as literal: " << literal->ToString() << " " - << shaped_buffer.ToString(); - Tensor tensor; - TF_RETURN_IF_ERROR( - LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor)); - // Reshape the tensor back to its declared shape. - if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) { - return errors::Internal( - "Tensor::CopyFrom failed when copying from XLA device to CPU"); - } - return Status::OK(); + TensorReference ref(device_tensor); + transfer_manager_->TransferLiteralFromDevice( + stream_, shaped_buffer, + [=, &shaped_buffer]( + xla::StatusOr > literal_or) { + ref.Unref(); + done([&]() -> Status { + TF_ASSIGN_OR_RETURN(auto literal, std::move(literal_or)); + VLOG(1) << "Transfer from device as literal: " << literal->ToString() + << " " << shaped_buffer.ToString(); + Tensor tensor; + TF_RETURN_IF_ERROR( + LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor)); + // Reshape the tensor back to its declared shape. + Status status; + if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) { + status = errors::Internal( + "Tensor::CopyFrom failed when copying from XLA device to CPU"); + } + return status; + }()); + }); } void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, @@ -121,17 +138,16 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, TensorShape shape = shape_representation_fn_(device_tensor->shape(), device_tensor->dtype()); + Status status; if (!xla_tensor->has_shaped_buffer()) { - Status s = xla_tensor->AllocateShapedBuffer( + status = xla_tensor->AllocateShapedBuffer( device_tensor->dtype(), shape, client_, stream_->parent()->device_ordinal()); - if (!s.ok()) { - done(s); - return; + if (!status.ok()) { + return done(status); } } - Status status; if (transfer_as_literal_) { Tensor reshaped_cpu_tensor; if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) { @@ -184,7 +200,8 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, Status status; if (transfer_as_literal_) { - status = TransferLiteralFromDevice(cpu_tensor, *device_tensor); + TransferLiteralFromDevice(cpu_tensor, *device_tensor, done); + return; } else { stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes); // TODO(hpucha): Make this asynchronous. @@ -194,9 +211,8 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, "Failed to complete data transfer on stream %p: %s", stream_, block_status.error_message().c_str()); } + done(status); } - - done(status); return; } @@ -207,8 +223,8 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, const StatusCallback& done) { - // TODO(phawkins): replace this code with an asynchronous implementation. - auto body = [&]() { + // Perform memory allocation now, and enqueue the device-to-device transfer. + Status status = [&]() -> Status { if (src_tensor.NumElements() == 0) { return Status::OK(); } @@ -223,21 +239,20 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_, stream_->parent()->device_ordinal())); } - TF_RETURN_IF_ERROR( - xla_dst->shaped_buffer().buffers().ForEachMutableElementWithStatus( - [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { - const se::DeviceMemoryBase& from_buffer = - xla_src->shaped_buffer().buffers().element(index); - CHECK_EQ(buffer->size(), from_buffer.size()); - if (!stream_->parent()->SynchronousMemcpy(buffer, from_buffer, - buffer->size())) { - return errors::Internal("Device to device memcpy failed"); - } - return Status::OK(); - })); + auto from_iter = xla_src->shaped_buffer().buffers().begin(); + auto to_iter = xla_dst->shaped_buffer().buffers().begin(); + for (auto end_iter = xla_src->shaped_buffer().buffers().end(); + from_iter != end_iter; ++from_iter, ++to_iter) { + stream_->ThenMemcpyD2D(&to_iter->second, from_iter->second, + to_iter->second.size()); + } return Status::OK(); - }; - done(body()); + }(); + if (!status.ok()) { + return done(status); + } else { + stream_->ThenDoHostCallback([=]() { done(Status::OK()); }); + } } XlaDeviceContext::XlaDeviceContext( diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index ee346e5653bbf9f393df202572c2150b4989506f..c5c81d65fe0f4a2774aab9f742454467e052071e 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -64,8 +64,9 @@ class XlaTransferManager { private: Status TransferLiteralToDevice(const Tensor& host_tensor, Tensor* device_tensor) const; - Status TransferLiteralFromDevice(Tensor* host_tensor, - const Tensor& device_tensor) const; + void TransferLiteralFromDevice(Tensor* host_tensor, + const Tensor& device_tensor, + const StatusCallback& done) const; // Stream obtained from a Device, used to transfer tensors between // CPU and device. diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 11e45d2823da2b623bd3cd45f7147686b05fdb2f..a605335a94f8687e0af4566f912b38dca9b5ac26 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -23,9 +23,11 @@ limitations under the License. #include "tensorflow/core/kernels/cast_op.h" #include "tensorflow/core/kernels/constant_op.h" #include "tensorflow/core/kernels/control_flow_ops.h" +#include "tensorflow/core/kernels/fifo_queue.h" #include "tensorflow/core/kernels/identity_n_op.h" #include "tensorflow/core/kernels/identity_op.h" #include "tensorflow/core/kernels/no_op.h" +#include "tensorflow/core/kernels/queue_op.h" #include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/kernels/sendrecv_ops.h" #include "tensorflow/core/kernels/shape_ops.h" @@ -145,7 +147,32 @@ class XlaAssignVariableOp : public AsyncOpKernel { .Device(DEVICE) \ .HostMemory("input") \ .HostMemory("output"), \ - LoopCondOp); + LoopCondOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueEnqueueV2").Device(DEVICE).HostMemory("handle"), EnqueueOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueDequeueV2").Device(DEVICE).HostMemory("handle"), DequeueOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueCloseV2").Device(DEVICE).HostMemory("handle"), QueueCloseOp); \ + REGISTER_KERNEL_BUILDER(Name("QueueSizeV2") \ + .Device(DEVICE) \ + .HostMemory("size") \ + .HostMemory("handle"), \ + QueueSizeOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueIsClosedV2").Device(DEVICE).HostMemory("handle"), \ + QueueIsClosedOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); + +// TODO(phawkins): currently we do not register the QueueEnqueueMany, +// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read +// and write the tensors they access in order to concatenate them into a batch. +// We would need either to call out to an XLA computation to perform the +// concatenation, or we would need to refactor those kernels so the splitting +// or merging is done in a separate operator that can be compiled. } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index d0c7a9365125708b2af43f87c7617d8d84050a61..5ceccc769fa2e95d4cf4d2b4ebd8dbf312ebdfd0 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -176,6 +176,21 @@ void XlaComputationLaunchContext::PopulateOutputs( } CHECK_EQ(ctx->num_outputs(), kernel->outputs.size()); + // If the on-host-shape isn't a tuple, create a new single-element tuple + // buffer with a nullptr root index table. This allows the code below to treat + // output as a tuple unconditionally. + if (!xla::ShapeUtil::IsTuple(output.on_host_shape())) { + ShapedBuffer nontuple_buffer = output.release(); + ShapedBuffer buffer( + xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}), + xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_device_shape()}), + output.platform(), output.device_ordinal()); + buffer.buffers().CopySubtreeFrom(nontuple_buffer.buffers(), + /*source_base_index=*/{}, + /*target_base_index=*/{0}); + output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator()); + } + // Copy XLA results to the OpOutputList. int output_num = 0; for (int i = 0; i < ctx->num_outputs(); ++i) { @@ -230,9 +245,14 @@ void XlaComputationLaunchContext::PopulateOutputs( Tensor* output_tensor; OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor)); XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); - CHECK(xla_tensor); - xla_tensor->set_shaped_buffer(ScopedShapedBuffer( - ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); + if (xla_tensor) { + xla_tensor->set_shaped_buffer(ScopedShapedBuffer( + ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); + } else { + // xla_tensor wasn't valid, which must mean this is a zero-element + // tensor. + CHECK_EQ(output_tensor->TotalBytes(), 0); + } } else { Tensor output_tensor = XlaTensorBuffer::MakeTensor( ctx->expected_output_dtype(i), shape, buffer, allocator); diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 98fab319d6f4fbf3159b6e8815baea262b882d2a..366822f0b74ca8afe1d203449357e19dc0242445 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -51,6 +51,16 @@ py_library( ], ) +py_library( + name = "test_utils", + testonly = 1, + srcs = ["test_utils.py"], + srcs_version = "PY2AND3", + deps = [ + "//third_party/py/numpy", + ], +) + py_test( name = "xla_test_test", size = "small", @@ -247,6 +257,7 @@ tf_xla_py_test( srcs = ["conv2d_test.py"], shard_count = 10, deps = [ + ":test_utils", ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:framework", @@ -254,6 +265,7 @@ tf_xla_py_test( "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", + "@absl_py//absl/testing:parameterized", ], ) @@ -359,6 +371,20 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "fifo_queue_test", + size = "medium", + srcs = ["fifo_queue_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "fft_test", size = "medium", @@ -548,8 +574,11 @@ tf_xla_py_test( name = "random_ops_test", size = "small", srcs = ["random_ops_test.py"], - # TODO(b/31361304): enable RNG ops on GPU when parallelized. disabled_backends = [ + # TODO(b/110300529): RngNormal doesn't return values with the expected variance + "cpu", + "cpu_ondemand", + # TODO(b/31361304): enable RNG ops on GPU when parallelized. "gpu", ], deps = [ @@ -673,6 +702,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "sparse_to_dense_op_test", + size = "small", + srcs = ["sparse_to_dense_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + "//tensorflow/python:sparse_ops", + ], +) + tf_xla_py_test( name = "stack_ops_test", size = "small", @@ -752,9 +794,10 @@ tf_xla_py_test( tf_xla_py_test( name = "fused_batchnorm_test", - size = "small", + size = "medium", srcs = ["fused_batchnorm_test.py"], deps = [ + ":test_utils", ":xla_test", "//tensorflow/python:framework", "//tensorflow/python:math_ops", @@ -764,6 +807,7 @@ tf_xla_py_test( "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", "//tensorflow/python:training", + "@absl_py//absl/testing:parameterized", ], ) @@ -839,6 +883,18 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "sort_ops_test", + size = "small", + srcs = ["sort_ops_test.py"], + deps = [ + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + ], +) + tf_xla_py_test( name = "xla_device_test", size = "small", diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py index 9a93b3216404d8ed21fd6c57757bec1730c119b4..d775850a80e9f83f7b2c9f1cf8997dd50e229635 100644 --- a/tensorflow/compiler/tests/adagrad_test.py +++ b/tensorflow/compiler/tests/adagrad_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables @@ -28,7 +28,7 @@ from tensorflow.python.platform import test from tensorflow.python.training import adagrad -class AdagradOptimizerTest(XLATestCase): +class AdagradOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py index 3215dc36e5b2d517aa951db1b0d41188185ef93a..03554d6933aca39b428c6af4be0c78e2c7ccb0c9 100644 --- a/tensorflow/compiler/tests/adam_test.py +++ b/tensorflow/compiler/tests/adam_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops @@ -48,7 +48,7 @@ def adam_update_numpy(param, return param_t, m_t, v_t -class AdamOptimizerTest(XLATestCase): +class AdamOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 1e4dd32916c3a40282735fb8f75670b0e9ef0dc9..9cb3d0454608c37e669d5b4360bc39bf1bf7e68c 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops @@ -32,7 +32,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.platform import googletest -class BinaryOpsTest(XLATestCase): +class BinaryOpsTest(xla_test.XLATestCase): """Test cases for binary operators.""" def _testBinary(self, op, a, b, expected, equality_test=None): @@ -226,6 +226,11 @@ class BinaryOpsTest(XLATestCase): np.array([0b1, 0b101, 0b1000], dtype=dtype), np.array([0b0, 0b101, 0b1001], dtype=dtype), expected=np.array([0b1, 0b101, 0b1001], dtype=dtype)) + self._testSymmetricBinary( + bitwise_ops.bitwise_xor, + np.array([0b1, 0b111, 0b1100], dtype=dtype), + np.array([0b0, 0b101, 0b1001], dtype=dtype), + expected=np.array([0b1, 0b010, 0b0101], dtype=dtype)) lhs = np.array([0, 5, 3, 14], dtype=dtype) rhs = np.array([5, 0, 7, 11], dtype=dtype) @@ -1216,6 +1221,24 @@ class BinaryOpsTest(XLATestCase): np.array([1, 0], dtype=np.int32), expected=np.array([[1, 3], [2, 4]], dtype=dtype)) + def testConjugateTranspose(self): + for dtype in self.complex_types: + self._testBinary( + array_ops.conjugate_transpose, + np.zeros(shape=[1, 0, 4], dtype=dtype), + np.array([1, 2, 0], dtype=np.int32), + expected=np.zeros(shape=[0, 4, 1], dtype=dtype)) + self._testBinary( + array_ops.conjugate_transpose, + np.array([[1 - 1j, 2 + 2j], [3 - 3j, 4 + 4j]], dtype=dtype), + np.array([0, 1], dtype=np.int32), + expected=np.array([[1 + 1j, 2 - 2j], [3 + 3j, 4 - 4j]], dtype=dtype)) + self._testBinary( + array_ops.conjugate_transpose, + np.array([[1 - 1j, 2 + 2j], [3 - 3j, 4 + 4j]], dtype=dtype), + np.array([1, 0], dtype=np.int32), + expected=np.array([[1 + 1j, 3 + 3j], [2 - 2j, 4 - 4j]], dtype=dtype)) + def testCross(self): for dtype in self.float_types: self._testBinary( diff --git a/tensorflow/compiler/tests/bucketize_op_test.py b/tensorflow/compiler/tests/bucketize_op_test.py index fde9759a1c209844caac99d5f303cd3e406e5370..ef4d5f6322b7ae79b051795b5af7e6f7f1e55550 100644 --- a/tensorflow/compiler/tests/bucketize_op_test.py +++ b/tensorflow/compiler/tests/bucketize_op_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.ops import array_ops @@ -26,7 +26,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class BucketizationOpTest(XLATestCase): +class BucketizationOpTest(xla_test.XLATestCase): def testInt(self): with self.test_session() as sess: diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py index 035cdea1786d39f3d21bb63be5c8ccffe1608bdf..a4e7f75081dfd07fd4b5c94c33908aab8e7d8aa9 100644 --- a/tensorflow/compiler/tests/categorical_op_test.py +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -22,7 +22,7 @@ import collections import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops @@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest # TODO(srvasude): Merge this with # third_party/tensorflow/python/kernel_tests/random/multinomial_op_test.py. -class CategoricalTest(XLATestCase): +class CategoricalTest(xla_test.XLATestCase): """Test cases for random-number generating operators.""" def output_dtypes(self): diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py index 1a8989d7c2f617525c301f30fd899a01362310bf..d2867278af93812eae804b66a7a6b706f98fa600 100644 --- a/tensorflow/compiler/tests/cholesky_op_test.py +++ b/tensorflow/compiler/tests/cholesky_op_test.py @@ -23,7 +23,7 @@ import unittest import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -32,7 +32,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class CholeskyOpTest(XLATestCase): +class CholeskyOpTest(xla_test.XLATestCase): # Cholesky defined for float64, float32, complex64, complex128 # (https://www.tensorflow.org/api_docs/python/tf/cholesky) diff --git a/tensorflow/compiler/tests/clustering_test.py b/tensorflow/compiler/tests/clustering_test.py index 574f82fc717818334ac5d72ebef2191f1c18e669..e42ebf8f9e01dab13cde15979ffc42b7c0fbc57b 100644 --- a/tensorflow/compiler/tests/clustering_test.py +++ b/tensorflow/compiler/tests/clustering_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" -class ClusteringTest(XLATestCase): +class ClusteringTest(xla_test.XLATestCase): def testAdd(self): val1 = np.array([4, 3, 2, 1], dtype=np.float32) diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index f10973e19f1945515b776cf86349445ed7334629..d9ad4281477e87f79f2ecb52989ae86a5030d0cc 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -30,7 +30,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class ConcatTest(XLATestCase): +class ConcatTest(xla_test.XLATestCase): def testHStack(self): with self.test_session(): @@ -292,7 +292,7 @@ class ConcatTest(XLATestCase): array_ops.concat([scalar, scalar, scalar], dim) -class ConcatOffsetTest(XLATestCase): +class ConcatOffsetTest(xla_test.XLATestCase): def testBasic(self): with self.test_session() as sess: @@ -306,7 +306,7 @@ class ConcatOffsetTest(XLATestCase): self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]]) -class PackTest(XLATestCase): +class PackTest(xla_test.XLATestCase): def testBasic(self): with self.test_session() as sess: diff --git a/tensorflow/compiler/tests/conv2d_test.py b/tensorflow/compiler/tests/conv2d_test.py index 62577b70ce96e220d79978f01614b2d9a3647680..98d41ba7edd52eedbf035097a48a1ce2ac7d5e9e 100644 --- a/tensorflow/compiler/tests/conv2d_test.py +++ b/tensorflow/compiler/tests/conv2d_test.py @@ -22,9 +22,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import test_utils +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops @@ -32,7 +34,15 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.platform import googletest -class Conv2DTest(XLATestCase): +DATA_FORMATS = ( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), + ("_data_format_HWNC", "HWNC"), + ("_data_format_HWCN", "HWCN"), +) + + +class Conv2DTest(xla_test.XLATestCase, parameterized.TestCase): def _VerifyValues(self, input_sizes=None, @@ -40,6 +50,8 @@ class Conv2DTest(XLATestCase): strides=None, dilations=None, padding=None, + data_format_src="NHWC", + data_format_dst="NHWC", expected=None): """Tests that tf.nn.conv2d produces the expected value. @@ -51,8 +63,12 @@ class Conv2DTest(XLATestCase): strides: Strides. dilations: RHS dilations. padding: Padding type. + data_format_src: Data format input is in. + data_format_dst: Data format verification will run and input is converted + to. expected: Expected output. """ + total_size_1 = np.prod(input_sizes) total_size_2 = np.prod(filter_sizes) x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(input_sizes) @@ -62,6 +78,18 @@ class Conv2DTest(XLATestCase): dilations = [1, 1] dilations = [1] + dilations + [1] + # Convert between data formats. + expected = test_utils.ConvertBetweenDataFormats(expected, data_format_src, + data_format_dst) + x1 = test_utils.ConvertBetweenDataFormats(x1, data_format_src, + data_format_dst) + input_sizes = test_utils.PermuteDimsBetweenDataFormats( + input_sizes, data_format_src, data_format_dst) + strides = test_utils.PermuteDimsBetweenDataFormats(strides, data_format_src, + data_format_dst) + dilations = test_utils.PermuteDimsBetweenDataFormats( + dilations, data_format_src, data_format_dst) + with self.test_session() as sess: t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) t2 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) @@ -71,12 +99,14 @@ class Conv2DTest(XLATestCase): t2, strides=strides, padding=padding, - data_format="NHWC", + data_format=data_format_dst, dilations=dilations) + value = sess.run(out, {t1: x1, t2: x2}) self.assertAllClose(expected, value, 1e-3) - def testConv2D1x1Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x1Filter(self, data_format): expected_output = np.reshape([ 30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, 138.0, 171.0, 204.0, 174.0, 216.0, 258.0, 210.0, 261.0, 312.0 @@ -86,9 +116,12 @@ class Conv2DTest(XLATestCase): filter_sizes=[1, 1, 3, 3], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Filter(self, data_format): expected_output = np.reshape( [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0], [1, 1, 2, 3]) self._VerifyValues( @@ -96,9 +129,12 @@ class Conv2DTest(XLATestCase): filter_sizes=[2, 2, 3, 3], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2Filter2x1Dilation(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Filter2x1Dilation(self, data_format): expected_output = np.array([[[[72], [82], [92]], [[112], [122], [132]]]]) self._VerifyValues( input_sizes=[1, 4, 4, 1], @@ -106,9 +142,12 @@ class Conv2DTest(XLATestCase): strides=[1, 1], dilations=[2, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2Filter(self, data_format): expected_output = np.reshape([ 231.0, 252.0, 273.0, 384.0, 423.0, 462.0, 690.0, 765.0, 840.0, 843.0, 936.0, 1029.0 @@ -118,18 +157,24 @@ class Conv2DTest(XLATestCase): filter_sizes=[1, 2, 3, 3], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterStride2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterStride2(self, data_format): expected_output = np.reshape([2271.0, 2367.0, 2463.0], [1, 1, 1, 3]) self._VerifyValues( input_sizes=[1, 2, 3, 3], filter_sizes=[2, 2, 3, 3], strides=[2, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterStride2Same(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterStride2Same(self, data_format): expected_output = np.reshape( [2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0], [1, 1, 2, 3]) self._VerifyValues( @@ -137,47 +182,61 @@ class Conv2DTest(XLATestCase): filter_sizes=[2, 2, 3, 3], strides=[2, 2], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2DEmptyDilation(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2DEmptyDilation(self, data_format): self._VerifyValues( input_sizes=[0, 2, 3, 3], filter_sizes=[1, 1, 3, 3], strides=[1, 1], dilations=[2, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=np.zeros([0, 2, 3, 3])) - def testConv2D2x2FilterDilation(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterDilation(self, data_format): self._VerifyValues( input_sizes=[1, 2, 3, 3], filter_sizes=[2, 2, 3, 3], strides=[1, 1], dilations=[1, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=np.reshape([2667, 2781, 2895], [1, 1, 1, 3])) - def testConv2D1x2FilterDilation(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterDilation(self, data_format): self._VerifyValues( input_sizes=[1, 2, 3, 3], filter_sizes=[1, 2, 3, 3], strides=[1, 1], dilations=[2, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=np.array([[[[231, 252, 273], [384, 423, 462]], [[690, 765, 840], [843, 936, 1029]]]])) - def testConv2DKernelSizeMatchesInputSizeDilation(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2DKernelSizeMatchesInputSizeDilation(self, data_format): self._VerifyValues( input_sizes=[1, 3, 3, 1], filter_sizes=[2, 2, 1, 2], strides=[1, 1], dilations=[2, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=np.reshape([108, 128], [1, 1, 1, 2])) -class Conv2DBackpropInputTest(XLATestCase): +class Conv2DBackpropInputTest(xla_test.XLATestCase, parameterized.TestCase): def _VerifyValues(self, input_sizes=None, @@ -186,6 +245,8 @@ class Conv2DBackpropInputTest(XLATestCase): strides=None, dilations=None, padding=None, + data_format_src="NHWC", + data_format_dst="NHWC", expected=None): """Tests that gen_nn_ops.conv2d_backprop_input produces the expected output. @@ -198,8 +259,12 @@ class Conv2DBackpropInputTest(XLATestCase): strides: Strides. dilations: Dilations. padding: Padding type. + data_format_src: Data format input is in. + data_format_dst: Data format verification will run and input is converted + to. expected: Expected output. """ + total_size_1 = np.prod(filter_sizes) total_size_2 = np.prod(out_backprop_sizes) x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(filter_sizes) @@ -209,6 +274,23 @@ class Conv2DBackpropInputTest(XLATestCase): if dilations is not None: dilations = [1] + dilations + [1] + expected = np.reshape(expected, input_sizes) + + # Convert between data formats. + expected = test_utils.ConvertBetweenDataFormats(expected, data_format_src, + data_format_dst) + x2 = test_utils.ConvertBetweenDataFormats(x2, data_format_src, + data_format_dst) + input_sizes = test_utils.PermuteDimsBetweenDataFormats( + input_sizes, data_format_src, data_format_dst) + out_backprop_sizes = test_utils.PermuteDimsBetweenDataFormats( + out_backprop_sizes, data_format_src, data_format_dst) + strides = test_utils.PermuteDimsBetweenDataFormats(strides, data_format_src, + data_format_dst) + if dilations is not None: + dilations = test_utils.PermuteDimsBetweenDataFormats( + dilations, data_format_src, data_format_dst) + with self.test_session() as sess: t1 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) @@ -220,12 +302,14 @@ class Conv2DBackpropInputTest(XLATestCase): strides=strides, dilations=dilations, padding=padding, - data_format="NHWC") + data_format=data_format_dst) + value = sess.run(out, {t1: x1, t2: x2}) self.assertAllEqual(input_sizes, value.shape) - self.assertAllClose(expected, np.ravel(value), 1e-3) + self.assertAllClose(expected, value, 1e-3) - def testConv2D1x1Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x1Filter(self, data_format): expected_output = [ 5, 11, 17, 11, 25, 39, 17, 39, 61, 23, 53, 83, 29, 67, 105, 35, 81, 127, 41, 95, 149, 47, 109, 171, 53, 123, 193, 59, 137, 215, 65, 151, 237, 71, @@ -237,9 +321,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 4, 4, 2], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2FilterStride3Width5(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterStride3Width5(self, data_format): expected_output = [1, 2, 0, 2, 4] self._VerifyValues( input_sizes=[1, 1, 5, 1], @@ -247,9 +334,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[3, 3], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2FilterStride3Width6(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterStride3Width6(self, data_format): expected_output = [1, 2, 0, 2, 4, 0] self._VerifyValues( input_sizes=[1, 1, 6, 1], @@ -257,9 +347,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[3, 3], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2FilterStride3Width7(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterStride3Width7(self, data_format): expected_output = [1, 2, 0, 2, 4, 0, 0] self._VerifyValues( input_sizes=[1, 1, 7, 1], @@ -267,9 +360,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[3, 3], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterC1Same(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterC1Same(self, data_format): expected_output = [1, 4, 7, 7, 23, 33] self._VerifyValues( input_sizes=[1, 2, 3, 1], @@ -277,9 +373,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 2, 3, 1], strides=[1, 1], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Filter(self, data_format): expected_output = [ 14, 32, 50, 100, 163, 226, 167, 212, 257, 122, 140, 158, 478, 541, 604, 437, 482, 527 @@ -290,9 +389,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 3], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterSame(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterSame(self, data_format): expected_output = [ 14, 32, 50, 100, 163, 226, 217, 334, 451, 190, 307, 424, 929, 1217, 1505, 1487, 1883, 2279 @@ -303,9 +405,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 2, 3, 3], strides=[1, 1], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2Filter(self, data_format): expected_output = [1, 4, 4, 3, 10, 8, 5, 16, 12] self._VerifyValues( input_sizes=[1, 3, 3, 1], @@ -313,9 +418,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 3, 2, 1], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2FilterSame(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterSame(self, data_format): expected_output = [1, 4, 7, 4, 13, 16, 7, 22, 25] self._VerifyValues( input_sizes=[1, 3, 3, 1], @@ -323,9 +431,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 3, 3, 1], strides=[1, 1], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterStride2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterStride2(self, data_format): expected_output = [1, 2, 5, 4, 6, 0, 0, 0, 0, 0, 3, 6, 13, 8, 12] self._VerifyValues( input_sizes=[1, 3, 5, 1], @@ -333,9 +444,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 2, 2, 1], strides=[2, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterStride2Same(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterStride2Same(self, data_format): expected_output = [1, 2, 2, 3, 4, 6] self._VerifyValues( input_sizes=[1, 2, 3, 1], @@ -343,9 +457,13 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[2, 2], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2Depth3ValidBackpropInputStride1x1Dilation2x1(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Depth3ValidBackpropInputStride1x1Dilation2x1( + self, data_format): self._VerifyValues( input_sizes=[1, 3, 6, 1], filter_sizes=[2, 2, 1, 1], @@ -353,9 +471,12 @@ class Conv2DBackpropInputTest(XLATestCase): strides=[1, 1], dilations=[2, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[1, 4, 7, 10, 13, 10, 0, 0, 0, 0, 0, 0, 3, 10, 17, 24, 31, 20]) - def testConv2D2x2Depth1ValidBackpropInputDilation1x2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Depth1ValidBackpropInputDilation1x2(self, data_format): self._VerifyValues( input_sizes=[1, 2, 3, 1], filter_sizes=[2, 2, 1, 1], @@ -363,9 +484,12 @@ class Conv2DBackpropInputTest(XLATestCase): strides=[1, 1], dilations=[1, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[1, 0, 2, 3, 0, 4]) - def testConv2DEmptyBackpropInputDilation1x2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2DEmptyBackpropInputDilation1x2(self, data_format): self._VerifyValues( input_sizes=[0, 2, 3, 1], filter_sizes=[2, 2, 1, 1], @@ -373,9 +497,12 @@ class Conv2DBackpropInputTest(XLATestCase): strides=[1, 1], dilations=[1, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=np.zeros([0])) - def testConv2D2x2Depth3ValidBackpropInputDilation2x1(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Depth3ValidBackpropInputDilation2x1(self, data_format): # The GPU version of this test is not very stable. So adjusting the # error threshold to 1e-4. self._VerifyValues( @@ -385,12 +512,16 @@ class Conv2DBackpropInputTest(XLATestCase): strides=[1, 1], dilations=[2, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[ 14, 32, 50, 68, 86, 104, 0, 0, 0, 0, 0, 0, 122, 140, 158, 176, 194, 212 ]) - def testConv2DKernelSizeMatchesInputSizeBackpropInputDilation2x2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2DKernelSizeMatchesInputSizeBackpropInputDilation2x2( + self, data_format): self._VerifyValues( input_sizes=[1, 3, 3, 1], filter_sizes=[2, 2, 1, 2], @@ -398,10 +529,12 @@ class Conv2DBackpropInputTest(XLATestCase): strides=[1, 1], dilations=[2, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[5, 0, 11, 0, 0, 0, 17, 0, 23]) -class Conv2DBackpropFilterTest(XLATestCase): +class Conv2DBackpropFilterTest(xla_test.XLATestCase, parameterized.TestCase): def _VerifyValues(self, input_sizes=None, @@ -410,6 +543,8 @@ class Conv2DBackpropFilterTest(XLATestCase): strides=None, dilations=None, padding=None, + data_format_src="NHWC", + data_format_dst="NHWC", expected=None): """Tests that gen_nn_ops.conv2d_backprop_filter produces the right output. @@ -422,6 +557,9 @@ class Conv2DBackpropFilterTest(XLATestCase): strides: Stride. dilations: Dilations. padding: Padding type. + data_format_src: Data format input is in. + data_format_dst: Data format verification will run and input is converted + to. expected: Expected output. """ @@ -434,6 +572,23 @@ class Conv2DBackpropFilterTest(XLATestCase): if dilations is not None: dilations = [1] + dilations + [1] + expected = np.reshape(expected, filter_sizes) + + # Convert between data formats. + x1 = test_utils.ConvertBetweenDataFormats(x1, data_format_src, + data_format_dst) + x2 = test_utils.ConvertBetweenDataFormats(x2, data_format_src, + data_format_dst) + input_sizes = test_utils.PermuteDimsBetweenDataFormats( + input_sizes, data_format_src, data_format_dst) + out_backprop_sizes = test_utils.PermuteDimsBetweenDataFormats( + out_backprop_sizes, data_format_src, data_format_dst) + strides = test_utils.PermuteDimsBetweenDataFormats(strides, data_format_src, + data_format_dst) + if dilations is not None: + dilations = test_utils.PermuteDimsBetweenDataFormats( + dilations, data_format_src, data_format_dst) + with self.test_session() as sess: t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) @@ -445,13 +600,14 @@ class Conv2DBackpropFilterTest(XLATestCase): strides=strides, dilations=dilations, padding=padding, - data_format="NHWC") + data_format=data_format_dst) value = sess.run(tensor, {t1: x1, t2: x2}) self.assertAllEqual(filter_sizes, value.shape) - self.assertAllClose(expected, np.ravel(value), 1e-3) + self.assertAllClose(expected, value, 1e-3) - def testConv2D1x1Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x1Filter(self, data_format): expected_output = [8056, 8432, 8312, 8704, 8568, 8976] self._VerifyValues( input_sizes=[1, 4, 4, 3], @@ -459,9 +615,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 4, 4, 2], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2Filter(self, data_format): expected_output = [120, 141] self._VerifyValues( input_sizes=[1, 3, 3, 1], @@ -469,9 +628,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 3, 2, 1], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterDepth1(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterDepth1(self, data_format): expected_output = [5, 8, 14, 17] self._VerifyValues( input_sizes=[1, 2, 3, 1], @@ -479,9 +641,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Filter(self, data_format): expected_output = [ 17, 22, 27, 22, 29, 36, 27, 36, 45, 32, 43, 54, 37, 50, 63, 42, 57, 72, 62, 85, 108, 67, 92, 117, 72, 99, 126, 77, 106, 135, 82, 113, 144, 87, @@ -493,9 +658,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 3], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2FilterStride3Width5(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterStride3Width5(self, data_format): expected_output = [9, 12] self._VerifyValues( input_sizes=[1, 1, 5, 1], @@ -503,9 +671,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[3, 3], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2FilterStride3Width6(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterStride3Width6(self, data_format): expected_output = [9, 12] self._VerifyValues( input_sizes=[1, 1, 6, 1], @@ -513,9 +684,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[3, 3], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2FilterStride3Width7(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterStride3Width7(self, data_format): expected_output = [9, 12] self._VerifyValues( input_sizes=[1, 1, 7, 1], @@ -523,9 +697,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[3, 3], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x3Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x3Filter(self, data_format): expected_output = [5, 8, 11] self._VerifyValues( input_sizes=[1, 1, 4, 1], @@ -533,9 +710,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x3FilterSame(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x3FilterSame(self, data_format): expected_output = [20, 30, 20] self._VerifyValues( input_sizes=[1, 1, 4, 1], @@ -543,9 +723,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 4, 1], strides=[1, 1], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x3FilterSameOutbackprop2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x3FilterSameOutbackprop2(self, data_format): expected_output = [7, 10, 3] self._VerifyValues( input_sizes=[1, 1, 4, 1], @@ -553,9 +736,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[2, 2], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterC1Same(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterC1Same(self, data_format): expected_output = [91, 58, 32, 17] self._VerifyValues( input_sizes=[1, 2, 3, 1], @@ -563,9 +749,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 2, 3, 1], strides=[1, 1], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterStride2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterStride2(self, data_format): expected_output = [92, 102, 112] self._VerifyValues( input_sizes=[1, 3, 5, 1], @@ -573,9 +762,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 2, 2, 1], strides=[2, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterStride2Same(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterStride2Same(self, data_format): expected_output = [7, 2, 16, 5] self._VerifyValues( input_sizes=[1, 2, 3, 1], @@ -583,9 +775,13 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[2, 2], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2Depth3ValidBackpropFilterStride1x1Dilation2x1(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Depth3ValidBackpropFilterStride1x1Dilation2x1( + self, data_format): self._VerifyValues( input_sizes=[1, 3, 6, 1], filter_sizes=[2, 2, 1, 1], @@ -593,9 +789,12 @@ class Conv2DBackpropFilterTest(XLATestCase): strides=[1, 1], dilations=[2, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[55, 70, 235, 250]) - def testConv2D2x2Depth1ValidBackpropFilterDilation1x2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Depth1ValidBackpropFilterDilation1x2(self, data_format): self._VerifyValues( input_sizes=[1, 2, 3, 1], filter_sizes=[2, 2, 1, 1], @@ -603,9 +802,12 @@ class Conv2DBackpropFilterTest(XLATestCase): strides=[1, 1], dilations=[1, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[1, 3, 4, 6]) - def testConv2DEmptyBackpropFilterDilation1x2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2DEmptyBackpropFilterDilation1x2(self, data_format): self._VerifyValues( input_sizes=[1, 2, 3, 1], filter_sizes=[2, 2, 1, 0], @@ -613,9 +815,12 @@ class Conv2DBackpropFilterTest(XLATestCase): strides=[1, 1], dilations=[1, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=np.zeros([0])) - def testConv2D2x2Depth3ValidBackpropFilterDilation2x2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Depth3ValidBackpropFilterDilation2x2(self, data_format): self._VerifyValues( input_sizes=[1, 3, 4, 3], filter_sizes=[2, 2, 3, 3], @@ -623,13 +828,17 @@ class Conv2DBackpropFilterTest(XLATestCase): strides=[1, 1], dilations=[2, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[ 17, 22, 27, 22, 29, 36, 27, 36, 45, 47, 64, 81, 52, 71, 90, 57, 78, 99, 137, 190, 243, 142, 197, 252, 147, 204, 261, 167, 232, 297, 172, 239, 306, 177, 246, 315 ]) - def testConv2DKernelSizeMatchesInputSizeBackpropFilterDilation2x2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2DKernelSizeMatchesInputSizeBackpropFilterDilation2x2( + self, data_format): self._VerifyValues( input_sizes=[1, 3, 3, 1], filter_sizes=[2, 2, 1, 2], @@ -637,6 +846,8 @@ class Conv2DBackpropFilterTest(XLATestCase): strides=[1, 1], dilations=[2, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[1, 2, 3, 6, 7, 14, 9, 18]) diff --git a/tensorflow/compiler/tests/conv3d_test.py b/tensorflow/compiler/tests/conv3d_test.py index 3bebf46511cbc471d3fbbbe92d28511fcc717387..31ee41f04f27d387415e9fa2c4fa70b33cab7b04 100644 --- a/tensorflow/compiler/tests/conv3d_test.py +++ b/tensorflow/compiler/tests/conv3d_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -33,7 +33,7 @@ from tensorflow.python.platform import googletest # Test cloned from # tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py -class Conv3DBackpropFilterV2GradTest(XLATestCase): +class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase): def testGradient(self): with self.test_session(), self.test_scope(): @@ -66,7 +66,7 @@ class Conv3DBackpropFilterV2GradTest(XLATestCase): # Test cloned from tensorflow/python/kernel_tests/conv3d_transpose_test.py -class Conv3DTransposeTest(XLATestCase): +class Conv3DTransposeTest(xla_test.XLATestCase): def testConv3DTransposeSingleStride(self): with self.test_session(), self.test_scope(): diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py index 03d96a2cd8ab22a472a67f092e36224820405fa8..98dc73e189f99b7b811487756659d89dacb97d8a 100644 --- a/tensorflow/compiler/tests/depthwise_conv_op_test.py +++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -114,7 +114,7 @@ def CheckGradConfigsToTest(): yield i, f, o, s, p -class DepthwiseConv2DTest(XLATestCase): +class DepthwiseConv2DTest(xla_test.XLATestCase): # This is testing that depthwise_conv2d and depthwise_conv2d_native # produce the same results. It also tests that NCHW and NWHC diff --git a/tensorflow/compiler/tests/dynamic_slice_ops_test.py b/tensorflow/compiler/tests/dynamic_slice_ops_test.py index 6a46d2ec3e7aee3a4ecfbf1ab9f622d8eb659e3c..154e36b10e6da409606ae6022aaf53e34c8e37cc 100644 --- a/tensorflow/compiler/tests/dynamic_slice_ops_test.py +++ b/tensorflow/compiler/tests/dynamic_slice_ops_test.py @@ -20,14 +20,14 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class DynamicUpdateSliceOpsTest(XLATestCase): +class DynamicUpdateSliceOpsTest(xla_test.XLATestCase): def _assertOpOutputMatchesExpected(self, op, args, expected): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/dynamic_stitch_test.py b/tensorflow/compiler/tests/dynamic_stitch_test.py index c109c27abe2f145685f83251e1d21ec8ddad563a..edd78153b56bb5bf1c268936fb82a60581389733 100644 --- a/tensorflow/compiler/tests/dynamic_stitch_test.py +++ b/tensorflow/compiler/tests/dynamic_stitch_test.py @@ -20,14 +20,14 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.platform import googletest -class DynamicStitchTest(XLATestCase): +class DynamicStitchTest(xla_test.XLATestCase): def _AssertDynamicStitchResultIs(self, indices, data, expected): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index a4154ad1e846f8241a2ab6598da36ccb6b3b653e..3524666499cbb2ef3eae2bb3b314dda0a9be64c8 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import backprop from tensorflow.python.eager import context @@ -40,7 +40,7 @@ from tensorflow.python.platform import googletest from tensorflow.python.training import adam -class EagerTest(XLATestCase): +class EagerTest(xla_test.XLATestCase): def testBasic(self): with self.test_scope(): @@ -49,6 +49,21 @@ class EagerTest(XLATestCase): product = three * five self.assertAllEqual(15, product) + def testGradientTape(self): + with self.test_scope(): + + x = constant_op.constant(1.0) + y = constant_op.constant(10.0) + with backprop.GradientTape(persistent=True) as tape: + tape.watch(x) + tape.watch(y) + a = x + y + x * y + da_dx = tape.gradient(a, x) + da_dy = tape.gradient(a, y) + + self.assertEqual(11.0, da_dx.numpy()) + self.assertEqual(2.0, da_dy.numpy()) + def testExecuteListOutputLen0(self): with self.test_scope(): empty = constant_op.constant([], dtype=dtypes.float32) @@ -271,11 +286,11 @@ class EagerTest(XLATestCase): [2.0, 2.0]], embedding_matrix.numpy()) -class EagerFunctionTest(XLATestCase): +class EagerFunctionTest(xla_test.XLATestCase): def testBasic(self): with self.test_scope(): - matmul = function.defun(math_ops.matmul, compiled=True) + matmul = function.defun(math_ops.matmul) t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) sq = matmul(t, t, transpose_a=True) self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20]) @@ -297,7 +312,7 @@ class EagerFunctionTest(XLATestCase): def model(x): x = conv(x) return pool(x) - model = function.defun(model, compiled=True) + model = function.defun(model) x = array_ops.ones([1, 4, 4, 1]) y = model(x) @@ -307,7 +322,7 @@ class EagerFunctionTest(XLATestCase): with self.test_scope(): v = resource_variable_ops.ResourceVariable(1.0) - @function.defun(compiled=True) + @function.defun def f(): return v.read_value() @@ -322,7 +337,7 @@ class EagerFunctionTest(XLATestCase): v.assign_add(1.0) return v - f = function.defun(f, compiled=True) + f = function.defun(f) var = f(v) self.assertEqual(2.0, var.numpy()) @@ -350,7 +365,7 @@ class EagerFunctionTest(XLATestCase): d = r2 * v2 return a, b, c, d - foo = function.defun(foo, compiled=True) + foo = function.defun(foo) c1 = [0, 0] c2 = array_ops.ones([2], dtype=dtypes.int32) @@ -372,7 +387,7 @@ class EagerFunctionTest(XLATestCase): with self.test_scope(): v0 = resource_variable_ops.ResourceVariable(5.0) - @function.defun(compiled=True) + @function.defun def f(x): x = v0 * v0 * x return x @@ -385,8 +400,26 @@ class EagerFunctionTest(XLATestCase): self.assertEqual(75, y.numpy()) self.assertEqual(30, dy.numpy()) + def testSliceInDefun(self): + with self.test_scope(): -class ExcessivePaddingTest(XLATestCase): + @function.defun(compiled=True) + def f(x, y): + return x[0::2, y:, ...] + + x = array_ops.ones([2, 3, 4]) + y = array_ops.ones([], dtype=dtypes.int32) + with backprop.GradientTape() as tape: + tape.watch(x) + tape.watch(y) + z = f(x, y) + dz = tape.gradient(z, x) + + self.assertAllEqual(np.ones([1, 2, 4]), z.numpy()) + self.assertAllEqual((2, 3, 4), dz.shape.as_list()) + + +class ExcessivePaddingTest(xla_test.XLATestCase): """Test that eager execution works with TPU flattened tensors. Tensors that would normally be excessively padded when written @@ -417,7 +450,7 @@ class ExcessivePaddingTest(XLATestCase): def testAsFunctionInput(self): with self.test_scope(): - @function.defun(compiled=True) + @function.defun def f(x): return math_ops.reduce_sum(x, axis=2) @@ -428,7 +461,7 @@ class ExcessivePaddingTest(XLATestCase): def testAsFunctionOutput(self): with self.test_scope(): - @function.defun(compiled=True) + @function.defun def f(x): return x * constant_op.constant(100 * [[[10.0, 2.0]]]) diff --git a/tensorflow/compiler/tests/extract_image_patches_op_test.py b/tensorflow/compiler/tests/extract_image_patches_op_test.py index 0361702e7af778176daed941d64e61198090daf2..5529fdbb090315e1d7f47589777d8a538c90db2b 100644 --- a/tensorflow/compiler/tests/extract_image_patches_op_test.py +++ b/tensorflow/compiler/tests/extract_image_patches_op_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class ExtractImagePatches(XLATestCase): +class ExtractImagePatches(xla_test.XLATestCase): """Functional tests for ExtractImagePatches op.""" def _VerifyValues(self, image, ksizes, strides, rates, padding, patches): diff --git a/tensorflow/compiler/tests/fake_quant_ops_test.py b/tensorflow/compiler/tests/fake_quant_ops_test.py index dfe9400ef0f55ca011d4e23ba5d735899ca2e054..c48ab178bf53558084fb500b2811c6f0b77a7943 100644 --- a/tensorflow/compiler/tests/fake_quant_ops_test.py +++ b/tensorflow/compiler/tests/fake_quant_ops_test.py @@ -17,14 +17,14 @@ from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.platform import googletest -class FakeQuantWithMinMaxArgsTest(XLATestCase): +class FakeQuantWithMinMaxArgsTest(xla_test.XLATestCase): """Test cases for FakeQuantWithMinMaxArgs operation.""" # 8 bits, wide range. @@ -122,7 +122,7 @@ class FakeQuantWithMinMaxArgsTest(XLATestCase): result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03) -class FakeQuantWithMinMaxArgsGradientTest(XLATestCase): +class FakeQuantWithMinMaxArgsGradientTest(xla_test.XLATestCase): """Test cases for FakeQuantWithMinMaxArgsGradient operation.""" # 8 bits, wide range. @@ -223,7 +223,7 @@ class FakeQuantWithMinMaxArgsGradientTest(XLATestCase): bfloat16_rtol=0.03) -class FakeQuantWithMinMaxVarsTest(XLATestCase): +class FakeQuantWithMinMaxVarsTest(xla_test.XLATestCase): """Test cases for FakeQuantWithMinMaxVars operation.""" # 8 bits, wide range. @@ -328,7 +328,7 @@ class FakeQuantWithMinMaxVarsTest(XLATestCase): result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03) -class FakeQuantWithMinMaxVarsGradientTest(XLATestCase): +class FakeQuantWithMinMaxVarsGradientTest(xla_test.XLATestCase): """Test cases for FakeQuantWithMinMaxVarsGradient operation.""" # 8 bits, wide range. diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py index afb5fa4bb4fefe5bc2ecded826143ffc83c2b559..c64ea249ecb97991952a960a6d16e1bb3be35b17 100644 --- a/tensorflow/compiler/tests/fft_test.py +++ b/tensorflow/compiler/tests/fft_test.py @@ -23,10 +23,11 @@ import itertools import numpy as np import scipy.signal as sps -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.contrib.signal.python.ops import spectral_ops as signal from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import spectral_ops from tensorflow.python.platform import googletest @@ -57,7 +58,7 @@ INNER_DIMS_2D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2)) INNER_DIMS_3D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2, POWS_OF_2)) -class FFTTest(XLATestCase): +class FFTTest(xla_test.XLATestCase): def _VerifyFftMethod(self, inner_dims, complex_to_input, input_to_expected, tf_method): @@ -97,8 +98,11 @@ class FFTTest(XLATestCase): ph = array_ops.placeholder( dtypes.as_dtype(data.dtype), shape=data.shape) out = signal.stft(ph, ws, hs) + grad = gradients_impl.gradients(out, ph, + grad_ys=array_ops.ones_like(out)) - value = sess.run(out, {ph: data}) + # For gradients, we simply verify that they compile & execute. + value, _ = sess.run([out, grad], {ph: data}) self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL) def testFFT(self): diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0f64cc87cde77fbbef6c4e570879e992bc34bafa --- /dev/null +++ b/tensorflow/compiler/tests/fifo_queue_test.py @@ -0,0 +1,201 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.data_flow_ops.FIFOQueue.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes as dtypes_lib +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.platform import test + + +class FIFOQueueTest(xla_test.XLATestCase): + + def testEnqueue(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + enqueue_op = q.enqueue((10.0,)) + enqueue_op.run() + + def testEnqueueWithShape(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2)) + enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],)) + enqueue_correct_op.run() + with self.assertRaises(ValueError): + q.enqueue(([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],)) + self.assertEqual(1, q.size().eval()) + + def testMultipleDequeues(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q.enqueue([1])) + self.evaluate(q.enqueue([2])) + self.evaluate(q.enqueue([3])) + a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()]) + self.assertAllEqual(set([1, 2, 3]), set([a, b, c])) + + def testQueuesDontShare(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q.enqueue(1)) + q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q2.enqueue(2)) + self.assertAllEqual(self.evaluate(q2.dequeue()), 2) + self.assertAllEqual(self.evaluate(q.dequeue()), 1) + + def testEnqueueDictWithoutNames(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + with self.assertRaisesRegexp(ValueError, "must have names"): + q.enqueue({"a": 12.0}) + + def testParallelEnqueue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + # Run one producer thread for each element in elems. + def enqueue(enqueue_op): + sess.run(enqueue_op) + + threads = [ + self.checkedThread(target=enqueue, args=(e,)) for e in enqueue_ops + ] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # Dequeue every element using a single thread. + results = [] + for _ in xrange(len(elems)): + results.append(dequeued_t.eval()) + self.assertItemsEqual(elems, results) + + def testParallelDequeue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + # Enqueue every element using a single thread. + for enqueue_op in enqueue_ops: + enqueue_op.run() + + # Run one consumer thread for each element in elems. + results = [] + + def dequeue(): + results.append(sess.run(dequeued_t)) + + threads = [self.checkedThread(target=dequeue) for _ in enqueue_ops] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + self.assertItemsEqual(elems, results) + + def testDequeue(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + for enqueue_op in enqueue_ops: + enqueue_op.run() + + for i in xrange(len(elems)): + vals = dequeued_t.eval() + self.assertEqual([elems[i]], vals) + + def testEnqueueAndBlockingDequeue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + def enqueue(): + # The enqueue_ops should run after the dequeue op has blocked. + # TODO(mrry): Figure out how to do this without sleeping. + time.sleep(0.1) + for enqueue_op in enqueue_ops: + sess.run(enqueue_op) + + results = [] + + def dequeue(): + for _ in xrange(len(elems)): + results.append(sess.run(dequeued_t)) + + enqueue_thread = self.checkedThread(target=enqueue) + dequeue_thread = self.checkedThread(target=dequeue) + enqueue_thread.start() + dequeue_thread.start() + enqueue_thread.join() + dequeue_thread.join() + + for elem, result in zip(elems, results): + self.assertEqual([elem], result) + + def testMultiEnqueueAndDequeue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32)) + elems = [(5, 10.0), (10, 20.0), (15, 30.0)] + enqueue_ops = [q.enqueue((x, y)) for x, y in elems] + dequeued_t = q.dequeue() + + for enqueue_op in enqueue_ops: + enqueue_op.run() + + for i in xrange(len(elems)): + x_val, y_val = sess.run(dequeued_t) + x, y = elems[i] + self.assertEqual([x], x_val) + self.assertEqual([y], y_val) + + def testQueueSizeEmpty(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + self.assertEqual([0], q.size().eval()) + + def testQueueSizeAfterEnqueueAndDequeue(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + enqueue_op = q.enqueue((10.0,)) + dequeued_t = q.dequeue() + size = q.size() + self.assertEqual([], size.get_shape()) + + enqueue_op.run() + self.assertEqual(1, size.eval()) + dequeued_t.op.run() + self.assertEqual(0, size.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index 8e6407dffdac3adbcda8cbca2109ef9196defa8c..1da97fd51217a0f28d4b3ba2ccfae3f6b094e65b 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables @@ -30,7 +30,7 @@ from tensorflow.python.training import ftrl from tensorflow.python.training import gradient_descent -class FtrlOptimizerTest(XLATestCase): +class FtrlOptimizerTest(xla_test.XLATestCase): def initVariableAndGradient(self, dtype): var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index 8a3f4b0bdc7a61d6cfa2ba7474ce8579e293a5c7..04fba444460e714ce96205361ac02ed492206b04 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function @@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -class FunctionTest(XLATestCase): +class FunctionTest(xla_test.XLATestCase): def testFunction(self): """Executes a simple TensorFlow function.""" diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index a80d69fa5f5099b8a8b67df0da9c92b957e9d194..132e42ac7a28d0769b0de12ea0cee6eae752b245 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -18,9 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import test_utils +from tensorflow.compiler.tests import xla_test from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gradient_checker @@ -28,7 +30,7 @@ from tensorflow.python.ops import nn from tensorflow.python.platform import test -class FusedBatchNormTest(XLATestCase): +class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): def _reference_training(self, x, scale, offset, epsilon, data_format): if data_format != "NHWC": @@ -63,24 +65,36 @@ class FusedBatchNormTest(XLATestCase): grad_offset = np.sum(grad_y, axis=(0, 1, 2)) return grad_x, grad_scale, grad_offset - def testInference(self): + @parameterized.named_parameters( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), + ("_data_format_HWNC", "HWNC"), + ("_data_format_HWCN", "HWCN"), + ) + def testInference(self, data_format): channel = 3 x_shape = [2, 2, 6, channel] scale_shape = [channel] x_val = np.random.random_sample(x_shape).astype(np.float32) scale_val = np.random.random_sample(scale_shape).astype(np.float32) - offset_val = np.random.random_sample(scale_shape).astype(np.float32) - data_format = "NHWC" + epsilon = 0.001 + data_format_src = "NHWC" + y_ref, mean_ref, var_ref = self._reference_training( + x_val, scale_val, offset_val, epsilon, data_format_src) + with self.test_session() as sess, self.test_scope(): # To avoid constant folding - t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x") + x_val_converted = test_utils.ConvertBetweenDataFormats( + x_val, data_format_src, data_format) + y_ref_converted = test_utils.ConvertBetweenDataFormats( + y_ref, data_format_src, data_format) + + t_val = array_ops.placeholder( + np.float32, shape=x_val_converted.shape, name="x") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") offset = array_ops.placeholder( np.float32, shape=scale_shape, name="offset") - epsilon = 0.001 - y_ref, mean_ref, var_ref = self._reference_training( - x_val, scale_val, offset_val, epsilon, data_format) y, mean, variance = nn.fused_batch_norm( t_val, scale, @@ -91,31 +105,39 @@ class FusedBatchNormTest(XLATestCase): data_format=data_format, is_training=False) - y_val, _, _ = sess.run( - [y, mean, - variance], {t_val: x_val, - scale: scale_val, - offset: offset_val}) - self.assertAllClose(y_val, y_ref, atol=1e-3) + y_val, _, _ = sess.run([y, mean, variance], { + t_val: x_val_converted, + scale: scale_val, + offset: offset_val + }) + self.assertAllClose(y_val, y_ref_converted, atol=1e-3) - def _testLearning(self, use_gradient_checker): + def _testLearning(self, use_gradient_checker, data_format): channel = 3 x_shape = [2, 2, 6, channel] scale_shape = [channel] x_val = np.random.random_sample(x_shape).astype(np.float32) scale_val = np.random.random_sample(scale_shape).astype(np.float32) - offset_val = np.random.random_sample(scale_shape).astype(np.float32) mean_val = np.random.random_sample(scale_shape).astype(np.float32) var_val = np.random.random_sample(scale_shape).astype(np.float32) - data_format = "NHWC" + epsilon = 0.001 + data_format_src = "NHWC" + y_ref, mean_ref, var_ref = self._reference_training( + x_val, scale_val, offset_val, epsilon, data_format_src) + with self.test_session() as sess, self.test_scope(): # To avoid constant folding - t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x") + x_val_converted = test_utils.ConvertBetweenDataFormats( + x_val, data_format_src, data_format) + y_ref_converted = test_utils.ConvertBetweenDataFormats( + y_ref, data_format_src, data_format) + + t_val = array_ops.placeholder( + np.float32, shape=x_val_converted.shape, name="x") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") offset = array_ops.placeholder( np.float32, shape=scale_shape, name="offset") - epsilon = 0.001 y, mean, var = nn.fused_batch_norm( t_val, scale, @@ -129,33 +151,50 @@ class FusedBatchNormTest(XLATestCase): if use_gradient_checker: err = gradient_checker.compute_gradient_error( t_val, - x_shape, + x_val_converted.shape, y, - x_shape, + x_val_converted.shape, extra_feed_dict={ - t_val: x_val, + t_val: x_val_converted, scale: scale_val, offset: offset_val }) self.assertLess(err, 1e-3) - y_val, mean_val, var_val = sess.run( - [y, mean, var], {t_val: x_val, - scale: scale_val, - offset: offset_val}) - y_ref, mean_ref, var_ref = self._reference_training( - x_val, scale_val, offset_val, epsilon, data_format) + y_val, mean_val, var_val = sess.run([y, mean, var], { + t_val: x_val_converted, + scale: scale_val, + offset: offset_val + }) self.assertAllClose(mean_val, mean_ref, atol=1e-3) - self.assertAllClose(y_val, y_ref, atol=1e-3) + self.assertAllClose(y_val, y_ref_converted, atol=1e-3) self.assertAllClose(var_val, var_ref, atol=1e-3) - def testLearning(self): - self._testLearning(False) + @parameterized.named_parameters( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), + ("_data_format_HWNC", "HWNC"), + ("_data_format_HWCN", "HWCN"), + ) + def testLearning(self, data_format): + self._testLearning(False, data_format) - def testLearningWithGradientChecker(self): - self._testLearning(True) + @parameterized.named_parameters( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), + ("_data_format_HWNC", "HWNC"), + ("_data_format_HWCN", "HWCN"), + ) + def testLearningWithGradientChecker(self, data_format): + self._testLearning(True, data_format) - def testGradientTraining(self): + @parameterized.named_parameters( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), + ("_data_format_HWNC", "HWNC"), + ("_data_format_HWCN", "HWCN"), + ) + def testGradientTraining(self, data_format): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. channel = 3 @@ -167,33 +206,48 @@ class FusedBatchNormTest(XLATestCase): mean_val = np.random.random_sample(scale_shape).astype(np.float32) var_val = np.random.random_sample(scale_shape).astype(np.float32) epsilon = 0.001 + data_format_src = "NHWC" + grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad( + x_val, grad_val, scale_val, mean_val, var_val, epsilon, data_format_src) with self.test_session() as sess, self.test_scope(): - grad = array_ops.placeholder(np.float32, shape=x_shape, name="grad") - x = array_ops.placeholder(np.float32, shape=x_shape, name="x") + grad_val_converted = test_utils.ConvertBetweenDataFormats( + grad_val, data_format_src, data_format) + x_val_converted = test_utils.ConvertBetweenDataFormats( + x_val, data_format_src, data_format) + grad_x_ref_converted = test_utils.ConvertBetweenDataFormats( + grad_x_ref, data_format_src, data_format) + + grad = array_ops.placeholder( + np.float32, shape=x_val_converted.shape, name="grad") + x = array_ops.placeholder( + np.float32, shape=x_val_converted.shape, name="x") mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean") var = array_ops.placeholder(np.float32, shape=scale_shape, name="var") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad( - grad, x, scale, mean, var, data_format="NHWC", is_training=True) + grad, x, scale, mean, var, data_format=data_format, is_training=True) grad_x_val, grad_scale_val, grad_offset_val = sess.run( [grad_x, grad_scale, grad_offset], { - grad: grad_val, - x: x_val, + grad: grad_val_converted, + x: x_val_converted, mean: mean_val, var: var_val, scale: scale_val }) - grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad( - x_val, grad_val, scale_val, mean_val, var_val, epsilon, "NHWC") - - self.assertAllClose(grad_x_val, grad_x_ref, atol=1e-2) + self.assertAllClose(grad_x_val, grad_x_ref_converted, atol=1e-2) self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2) self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3) - def testGradientInference(self): + @parameterized.named_parameters( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), + ("_data_format_HWNC", "HWNC"), + ("_data_format_HWCN", "HWCN"), + ) + def testGradientInference(self, data_format): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. channel = 3 @@ -204,33 +258,47 @@ class FusedBatchNormTest(XLATestCase): scale_val = np.random.random_sample(scale_shape).astype(np.float32) mean_val = np.random.random_sample(scale_shape).astype(np.float32) var_val = np.random.random_sample(scale_shape).astype(np.float32) + data_format_src = "NHWC" with self.test_session() as sess, self.test_scope(): - grad = array_ops.placeholder(np.float32, shape=x_shape, name="grad") - x = array_ops.placeholder(np.float32, shape=x_shape, name="x") + grad_val_converted = test_utils.ConvertBetweenDataFormats( + grad_val, data_format_src, data_format) + x_val_converted = test_utils.ConvertBetweenDataFormats( + x_val, data_format_src, data_format) + + grad = array_ops.placeholder( + np.float32, shape=x_val_converted.shape, name="grad") + x = array_ops.placeholder( + np.float32, shape=x_val_converted.shape, name="x") mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean") var = array_ops.placeholder(np.float32, shape=scale_shape, name="var") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") with self.test_scope(): out = gen_nn_ops.fused_batch_norm_grad( - grad, x, scale, mean, var, data_format="NHWC", is_training=False) + grad, + x, + scale, + mean, + var, + data_format=data_format, + is_training=False) grad_x, grad_scale, grad_offset, _, _ = out ref_x, ref_scale, ref_offset, _, _ = gen_nn_ops.fused_batch_norm_grad( - grad, x, scale, mean, var, data_format="NHWC", is_training=False) + grad, x, scale, mean, var, data_format=data_format, is_training=False) grad_x_val, grad_scale_val, grad_offset_val, = sess.run( [grad_x, grad_scale, grad_offset], { - grad: grad_val, - x: x_val, + grad: grad_val_converted, + x: x_val_converted, mean: mean_val, var: var_val, scale: scale_val }) grad_x_ref, grad_scale_ref, grad_offset_ref, = sess.run( [ref_x, ref_scale, ref_offset], { - grad: grad_val, - x: x_val, + grad: grad_val_converted, + x: x_val_converted, mean: mean_val, var: var_val, scale: scale_val diff --git a/tensorflow/compiler/tests/gather_nd_op_test.py b/tensorflow/compiler/tests/gather_nd_op_test.py index 9378b1db7245c0da3e8298e7dcd972491616b0cd..23b0aed34fb460f50c241e5a920cb4f6f613b947 100644 --- a/tensorflow/compiler/tests/gather_nd_op_test.py +++ b/tensorflow/compiler/tests/gather_nd_op_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class GatherNdTest(XLATestCase): +class GatherNdTest(xla_test.XLATestCase): def _runGather(self, params, indices): with self.test_session(): diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py index 1a8c4519118f69ce51ca9a5eb95a9d706c7766cc..e9c8ef7c91a728b7dfc948fd9b315e6c9102f6a3 100644 --- a/tensorflow/compiler/tests/gather_test.py +++ b/tensorflow/compiler/tests/gather_test.py @@ -136,6 +136,20 @@ class GatherTest(xla_test.XLATestCase): self.assertAllEqual( [[7]], gather.eval(feed_dict={params: [4, 7, 2], indices: [[1]]})) + def testGatherPrecision(self): + with self.test_session() as session, self.test_scope(): + data = np.array([[0, 0, 0, 0], [0, 2 * (1 + np.exp2(-8)), 0, 0], + [0, 0, 0, 0], [0.015789, 0.0985, 0.55789, 0.3842]]) + indices = np.array([1, 2, 3, 1]) + dtype = dtypes.float32 + params_np = self._buildParams(data, dtype) + params = array_ops.placeholder(dtype=dtype) + indices_tf = constant_op.constant(indices) + gather_t = array_ops.gather(params, indices_tf) + gather_val = session.run(gather_t, feed_dict={params: params_np}) + np_val = params_np[indices] + self.assertAllEqual(np_val, gather_val) + class GatherBenchmark(test.Benchmark): """Microbenchmarks for the gather op.""" diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 7cf953ef25ef5daf8a6d4fc9985ed8dbfb2081e5..8b01ef96db3e8ab58850df234c2e05b764be52ba 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -25,7 +25,7 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -41,7 +41,7 @@ def GenerateNumpyRandomRGB(shape): return np.random.randint(0, 256, shape) / 256. -class RGBToHSVTest(XLATestCase): +class RGBToHSVTest(xla_test.XLATestCase): def testBatch(self): # Build an arbitrary RGB image @@ -104,7 +104,7 @@ class RGBToHSVTest(XLATestCase): self.assertAllCloseAccordingToType(hsv_tf, hsv_np) -class AdjustContrastTest(XLATestCase): +class AdjustContrastTest(xla_test.XLATestCase): def _testContrast(self, x_np, y_np, contrast_factor): with self.test_session(): @@ -168,7 +168,7 @@ class AdjustContrastTest(XLATestCase): self.assertAllClose(y_tf, y_np, rtol=1e-5, atol=1e-5) -class AdjustHueTest(XLATestCase): +class AdjustHueTest(xla_test.XLATestCase): def testAdjustNegativeHue(self): x_shape = [2, 2, 3] @@ -303,7 +303,7 @@ class AdjustHueTest(XLATestCase): self._adjustHueTf(x_np, delta_h) -class AdjustSaturationTest(XLATestCase): +class AdjustSaturationTest(xla_test.XLATestCase): def _adjust_saturation(self, image, saturation_factor): image = ops.convert_to_tensor(image, name="image") @@ -403,7 +403,7 @@ class AdjustSaturationTest(XLATestCase): self.assertAllClose(y_fused, y_baseline, rtol=2e-5, atol=1e-5) -class ResizeBilinearTest(XLATestCase): +class ResizeBilinearTest(xla_test.XLATestCase): def _assertForwardOpMatchesExpected(self, image_np, diff --git a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py index 69bd8f7230d4394c45764d02a88fb0ec097c5756..253b45902fba2df64e5234f135b373cd2a0a7e2a 100644 --- a/tensorflow/compiler/tests/lrn_ops_test.py +++ b/tensorflow/compiler/tests/lrn_ops_test.py @@ -22,7 +22,7 @@ import copy import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -36,7 +36,7 @@ CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" # Local response normalization tests. The forward tests are copied from # tensorflow/python/kernel_tests/lrn_op_test.py -class LRNTest(XLATestCase): +class LRNTest(xla_test.XLATestCase): def _LRN(self, input_image, lrn_depth_radius=5, bias=1.0, alpha=1.0, beta=0.5): diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py index 29394f9ea5139b30f88f53de0469b27e37d79195..0d9f99f8a6803ecae5f9233518a1768109161ac0 100644 --- a/tensorflow/compiler/tests/matrix_band_part_test.py +++ b/tensorflow/compiler/tests/matrix_band_part_test.py @@ -19,14 +19,14 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class MatrixBandPartTest(XLATestCase): +class MatrixBandPartTest(xla_test.XLATestCase): def _testMatrixBandPart(self, dtype, shape): with self.test_session(): diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py index 5819b2bf2b55b9213a039c0ba82dd0bf1c738b00..2bb8a97bdaf5836a05501ab9754433e29ae34675 100644 --- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py +++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py @@ -22,7 +22,7 @@ import itertools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -35,7 +35,7 @@ def MakePlaceholder(x): return array_ops.placeholder(dtypes.as_dtype(x.dtype), shape=x.shape) -class MatrixTriangularSolveOpTest(XLATestCase): +class MatrixTriangularSolveOpTest(xla_test.XLATestCase): # MatrixTriangularSolve defined for float64, float32, complex64, complex128 # (https://www.tensorflow.org/api_docs/python/tf/matrix_triangular_solve) diff --git a/tensorflow/compiler/tests/momentum_test.py b/tensorflow/compiler/tests/momentum_test.py index af9394e7d7dc9cf7dd009420ff9c845aec8785bd..c2592c54cf83d41f0e3bdbc1f4dc9ff276ddb078 100644 --- a/tensorflow/compiler/tests/momentum_test.py +++ b/tensorflow/compiler/tests/momentum_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -30,7 +30,7 @@ from tensorflow.python.platform import test from tensorflow.python.training import momentum as momentum_lib -class MomentumOptimizerTest(XLATestCase): +class MomentumOptimizerTest(xla_test.XLATestCase): def _update_nesterov_momentum_numpy(self, var, accum, g, lr, momentum): var += accum * lr * momentum diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py index e4843b169b943b63346b783ddc50039030988ca5..da08225e9fc0d5a8ec21ee9961c4758fa38628b4 100644 --- a/tensorflow/compiler/tests/nary_ops_test.py +++ b/tensorflow/compiler/tests/nary_ops_test.py @@ -22,14 +22,14 @@ import unittest import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class NAryOpsTest(XLATestCase): +class NAryOpsTest(xla_test.XLATestCase): def _testNAry(self, op, args, expected, equality_fn=None): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py index 6f588d8ab562cb24f33c4c2987df22264aede027..2f9122645d3c5ccabc8130ac30a3f09cf4bc2de7 100644 --- a/tensorflow/compiler/tests/nullary_ops_test.py +++ b/tensorflow/compiler/tests/nullary_ops_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import googletest -class NullaryOpsTest(XLATestCase): +class NullaryOpsTest(xla_test.XLATestCase): def _testNullary(self, op, expected): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/placeholder_test.py b/tensorflow/compiler/tests/placeholder_test.py index 5e6d1313bd0336eba71fcf3658d949bd3342ae11..a75d99189b5b673261c9e48f1c5998ea0c575594 100644 --- a/tensorflow/compiler/tests/placeholder_test.py +++ b/tensorflow/compiler/tests/placeholder_test.py @@ -18,14 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest -class PlaceholderTest(XLATestCase): +class PlaceholderTest(xla_test.XLATestCase): def test_placeholder_with_default_default(self): with self.test_session() as sess, self.test_scope(): diff --git a/tensorflow/compiler/tests/pooling_ops_3d_test.py b/tensorflow/compiler/tests/pooling_ops_3d_test.py index 4eed903963a34a253ea5c409782d9a89a97a4fdf..17f860db61aeda98326a6820771d67ee948b6dda 100644 --- a/tensorflow/compiler/tests/pooling_ops_3d_test.py +++ b/tensorflow/compiler/tests/pooling_ops_3d_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -41,7 +41,7 @@ def _AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding): padding=padding) -class Pooling3DTest(XLATestCase): +class Pooling3DTest(xla_test.XLATestCase): def _VerifyValues(self, pool_func, input_sizes, window, strides, padding, expected): @@ -187,8 +187,14 @@ class Pooling3DTest(XLATestCase): padding="VALID", expected=[29.5, 32.5, 50.5, 53.5, 176.5, 179.5, 197.5, 200.5]) - def _VerifyGradient(self, pool_func, pool_grad_func, input_sizes, ksize, - strides, padding): + def _VerifyGradient(self, + pool_func, + pool_grad_func, + input_sizes, + ksize, + strides, + padding, + pool_grad_grad_func=None): """Verifies the output values of the pooling gradient function. Args: @@ -198,6 +204,7 @@ class Pooling3DTest(XLATestCase): ksize: The kernel size dimensions strides: The stride dimensions padding: Padding type. + pool_grad_grad_func: Second-order gradient function, if available. """ ksize = [1] + ksize + [1] strides = [1] + strides + [1] @@ -218,6 +225,8 @@ class Pooling3DTest(XLATestCase): output_gradient_vals = np.arange( 1, output_vals.size + 1, dtype=np.float32) output_gradient_vals = output_gradient_vals.reshape(output_vals.shape) + output_grad_grad_vals = np.arange(1, x.size + 1, dtype=np.float32) + output_grad_grad_vals = output_grad_grad_vals.reshape(x.shape) # Use the Tensorflow CPU pooling gradient to compute the expected input # gradients. @@ -236,6 +245,22 @@ class Pooling3DTest(XLATestCase): {inputs: x, output_gradients: output_gradient_vals}) + output_grad_gradients = array_ops.placeholder( + dtypes.float32, shape=expected_input_gradient_vals.shape) + if pool_grad_grad_func is not None: + expected_grad_gradients = pool_grad_grad_func( + inputs, + outputs, + output_grad_gradients, + ksize=ksize, + strides=strides, + padding=padding, + data_format="NDHWC") + expected_grad_gradients_vals = sess.run(expected_grad_gradients, { + inputs: x, + output_grad_gradients: output_grad_grad_vals + }) + # Run the gradient op on the XLA device with self.test_scope(): outputs = array_ops.placeholder(dtypes.float32, shape=output_vals.shape) @@ -246,6 +271,16 @@ class Pooling3DTest(XLATestCase): ksize=ksize, strides=strides, padding=padding) + if pool_grad_grad_func is not None: + actual_grad_gradients = pool_grad_grad_func( + inputs, + outputs, + output_grad_gradients, + ksize=ksize, + strides=strides, + padding=padding, + data_format="NDHWC") + actual = sess.run(actual_input_gradients, { inputs: x, outputs: output_vals, @@ -260,6 +295,22 @@ class Pooling3DTest(XLATestCase): atol=1e-6) self.assertShapeEqual(actual, inputs) + if pool_grad_grad_func is not None: + actual_grad_gradients_vals = sess.run( + actual_grad_gradients, { + inputs: x, + outputs: output_vals, + output_grad_gradients: output_grad_grad_vals + }) + + # Compare the Tensorflow and XLA results. + self.assertAllClose( + expected_grad_gradients_vals, + actual_grad_gradients_vals, + rtol=1e-4, + atol=1e-6) + self.assertShapeEqual(actual_grad_gradients_vals, outputs) + def testMaxPoolGradValidPadding1_1_3d(self): self._VerifyGradient( nn_ops.max_pool3d, @@ -267,7 +318,8 @@ class Pooling3DTest(XLATestCase): input_sizes=[1, 3, 3, 3, 1], ksize=[1, 1, 1], strides=[1, 1, 1], - padding="VALID") + padding="VALID", + pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad) def testMaxPoolGradValidPadding2_1_6_3d(self): self._VerifyGradient( @@ -276,9 +328,13 @@ class Pooling3DTest(XLATestCase): input_sizes=[2, 3, 3, 6, 3], ksize=[2, 2, 2], strides=[1, 1, 1], - padding="VALID") + padding="VALID", + pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad) def testMaxPoolGradValidPadding2_1_7_3d(self): + # TODO(b/73062247): the bfloat16 implementation of MaxPool3DGradGrad does + # not have enough precision for this test case to pass if + # pool_grad_grad_func is passed. self._VerifyGradient( nn_ops.max_pool3d, gen_nn_ops.max_pool3d_grad, @@ -294,7 +350,8 @@ class Pooling3DTest(XLATestCase): input_sizes=[2, 2, 2, 2, 3], ksize=[2, 2, 2], strides=[2, 2, 2], - padding="VALID") + padding="VALID", + pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad) def testMaxPoolGradSamePadding1_1_3d(self): self._VerifyGradient( @@ -303,7 +360,8 @@ class Pooling3DTest(XLATestCase): input_sizes=[2, 3, 2, 4, 1], ksize=[1, 1, 1], strides=[1, 1, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad) def testMaxPoolGradSamePadding2_1_3d(self): self._VerifyGradient( @@ -312,7 +370,8 @@ class Pooling3DTest(XLATestCase): input_sizes=[2, 3, 2, 4, 1], ksize=[2, 2, 2], strides=[1, 1, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad) def testMaxPoolGradSamePadding2_2_3d(self): self._VerifyGradient( @@ -321,7 +380,8 @@ class Pooling3DTest(XLATestCase): input_sizes=[2, 5, 2, 4, 3], ksize=[2, 2, 2], strides=[2, 2, 2], - padding="SAME") + padding="SAME", + pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad) def testMaxPoolGradSamePadding3_1_3d(self): self._VerifyGradient( @@ -330,7 +390,8 @@ class Pooling3DTest(XLATestCase): input_sizes=[1, 3, 3, 7, 1], ksize=[3, 3, 3], strides=[1, 1, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad) def testAvgPoolGradValidPadding1_1_3d(self): self._VerifyGradient( diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py index fe270af3d636c0824621f36360ce9e7d14d8fc91..9fc94752ea660f7fb8b2c792180f01485ad04419 100644 --- a/tensorflow/compiler/tests/pooling_ops_test.py +++ b/tensorflow/compiler/tests/pooling_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -69,7 +69,7 @@ def GetTestConfigs(): return test_configs -class PoolingTest(XLATestCase): +class PoolingTest(xla_test.XLATestCase): def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding, data_format, expected): @@ -288,7 +288,7 @@ class PoolingTest(XLATestCase): expected=expected_output) -class PoolGradTest(XLATestCase): +class PoolGradTest(xla_test.XLATestCase): CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index f13dff96203b5480480c2a2fc9ac38ca78b7f78a..b880b2a3fea3ee72af96396bc2d61b2887e6e9b8 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -18,17 +18,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import math + import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import special_math from tensorflow.python.platform import googletest -class RandomOpsTest(XLATestCase): +class RandomOpsTest(xla_test.XLATestCase): """Test cases for random-number generating operators.""" def _random_types(self): @@ -87,15 +90,52 @@ class RandomOpsTest(XLATestCase): self._testRngIsNotConstant(rng, dtypes.float32) def testTruncatedNormalIsInRange(self): - count = 10000 + count = 10000000 # TODO(b/34339814): implement inverse erf support for non-F32 types. for dtype in [dtypes.float32]: with self.test_session() as sess: with self.test_scope(): x = random_ops.truncated_normal(shape=[count], dtype=dtype, seed=42) y = sess.run(x) - self.assertTrue((y >= -2).sum() == count) - self.assertTrue((y <= 2).sum() == count) + + def normal_cdf(x): + return .5 * math.erfc(-x / math.sqrt(2)) + + def normal_pdf(x): + return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi) + + def probit(x, sess=sess): + return sess.run(special_math.ndtri(x)) + + a = -2. + b = 2. + mu = 0. + sigma = 1. + + alpha = (a - mu) / sigma + beta = (b - mu) / sigma + z = normal_cdf(beta) - normal_cdf(alpha) + + self.assertTrue((y >= a).sum() == count) + self.assertTrue((y <= b).sum() == count) + + # For more information on these calculations, see: + # Burkardt, John. "The Truncated Normal Distribution". + # Department of Scientific Computing website. Florida State University. + expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma + actual_mean = np.mean(y) + self.assertAllClose(actual_mean, expected_mean, atol=2e-4) + + expected_median = mu + probit( + (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma + actual_median = np.median(y) + self.assertAllClose(actual_median, expected_median, atol=8e-4) + + expected_variance = sigma**2 * (1 + ( + (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - ( + (normal_pdf(alpha) - normal_pdf(beta)) / z)**2) + actual_variance = np.var(y) + self.assertAllClose(actual_variance, expected_variance, rtol=3e-4) def testShuffle1d(self): with self.test_session() as sess: diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index 7420724bdbeab63b39542ada59328621febad895..cea2ec816f85e88b11e6e80c91c14fca9015f45c 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -22,7 +22,7 @@ import functools import itertools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.ops import array_ops @@ -30,7 +30,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class ReduceOpsTest(XLATestCase): +class ReduceOpsTest(xla_test.XLATestCase): def _testReduction(self, tf_reduce_fn, @@ -156,7 +156,7 @@ class ReduceOpsTest(XLATestCase): self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA) -class ReduceOpPrecisionTest(XLATestCase): +class ReduceOpPrecisionTest(xla_test.XLATestCase): def _testReduceSum(self, expected_result, diff --git a/tensorflow/compiler/tests/reduce_window_test.py b/tensorflow/compiler/tests/reduce_window_test.py index e78a63465b80644d8810d9fa7433653bc4639fed..c69b6837b0f88ced844faf3713a29a1c14c8790d 100644 --- a/tensorflow/compiler/tests/reduce_window_test.py +++ b/tensorflow/compiler/tests/reduce_window_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.python.framework import dtypes from tensorflow.python.framework import function @@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -class ReduceWindowTest(XLATestCase): +class ReduceWindowTest(xla_test.XLATestCase): """Test cases for xla.reduce_window.""" def _reduce_window(self, operand, init, reducer, **kwargs): diff --git a/tensorflow/compiler/tests/reverse_ops_test.py b/tensorflow/compiler/tests/reverse_ops_test.py index 18fabca28c9817fc8517595fa1694a18399f54b0..d01c676e7c2fe705344f26818350c46c30451c67 100644 --- a/tensorflow/compiler/tests/reverse_ops_test.py +++ b/tensorflow/compiler/tests/reverse_ops_test.py @@ -21,14 +21,14 @@ from __future__ import print_function import itertools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -class ReverseOpsTest(XLATestCase): +class ReverseOpsTest(xla_test.XLATestCase): def testReverseOneDim(self): shape = (7, 5, 9, 11) diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py index 1a5d05094e53cfecd9476d7d87f023e8a02d7458..ccfa63001653537c4d1b7140e3d745c126f9034b 100644 --- a/tensorflow/compiler/tests/reverse_sequence_op_test.py +++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class ReverseSequenceTest(XLATestCase): +class ReverseSequenceTest(xla_test.XLATestCase): def _testReverseSequence(self, x, diff --git a/tensorflow/compiler/tests/rmsprop_test.py b/tensorflow/compiler/tests/rmsprop_test.py index ecdce4f052bbe3eeae8697c02c891105103f4f69..9489fded32a7b6aada0543721a8bfe5f2d74575e 100644 --- a/tensorflow/compiler/tests/rmsprop_test.py +++ b/tensorflow/compiler/tests/rmsprop_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables @@ -28,7 +28,7 @@ from tensorflow.python.platform import test from tensorflow.python.training import rmsprop -class RmspropTest(XLATestCase): +class RmspropTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py index 3260e63b23226d736a7ddc0f21a94a8c791e0442..4292352e76ebcef7dbf41df7b857d2604a468117 100644 --- a/tensorflow/compiler/tests/scan_ops_test.py +++ b/tensorflow/compiler/tests/scan_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops @@ -69,7 +69,7 @@ def handle_options(func, x, axis, exclusive, reverse): return x -class CumsumTest(XLATestCase): +class CumsumTest(xla_test.XLATestCase): valid_dtypes = [np.float32] @@ -147,7 +147,7 @@ class CumsumTest(XLATestCase): math_ops.cumsum(input_tensor, [0]).eval() -class CumprodTest(XLATestCase): +class CumprodTest(xla_test.XLATestCase): valid_dtypes = [np.float32] diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py index 638946e234daf28dc4a34e6c33fc0f78b8e8699b..f606f88545d0b6f0b52cee9b93083a6bd91169bc 100644 --- a/tensorflow/compiler/tests/scatter_nd_op_test.py +++ b/tensorflow/compiler/tests/scatter_nd_op_test.py @@ -22,7 +22,7 @@ import functools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -68,7 +68,7 @@ def _NumpyUpdate(indices, updates, shape): return _NumpyScatterNd(ref, indices, updates, lambda p, u: u) -class ScatterNdTest(XLATestCase): +class ScatterNdTest(xla_test.XLATestCase): def _VariableRankTest(self, np_scatter, diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py index 4a9c0e7471f9cdb2a47b54705495d2dda9748890..772c20fd424577c3e06eeae409f424b77b52aa8a 100644 --- a/tensorflow/compiler/tests/segment_reduction_ops_test.py +++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py @@ -21,26 +21,40 @@ from __future__ import print_function import functools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class SegmentReductionOpsTest(XLATestCase): +class SegmentReductionOpsTest(xla_test.XLATestCase): """Test cases for segment reduction ops.""" - def UnsortedSegmentSum(self, data, indices, num_segments): + def _segmentReduction(self, op, data, indices, num_segments): with self.test_session() as sess, self.test_scope(): d = array_ops.placeholder(data.dtype, shape=data.shape) if isinstance(indices, int): i = array_ops.placeholder(np.int32, shape=[]) else: i = array_ops.placeholder(indices.dtype, shape=indices.shape) - return sess.run( - math_ops.unsorted_segment_sum(d, i, num_segments), - {d: data, - i: indices}) + return sess.run(op(d, i, num_segments), {d: data, i: indices}) + + def _unsortedSegmentSum(self, data, indices, num_segments): + return self._segmentReduction(math_ops.unsorted_segment_sum, data, indices, + num_segments) + + def _unsortedSegmentProd(self, data, indices, num_segments): + return self._segmentReduction(math_ops.unsorted_segment_prod, data, indices, + num_segments) + + def _unsortedSegmentMin(self, data, indices, num_segments): + return self._segmentReduction(math_ops.unsorted_segment_min, data, indices, + num_segments) + + def _unsortedSegmentMax(self, data, indices, num_segments): + return self._segmentReduction(math_ops.unsorted_segment_max, data, indices, + num_segments) def testUnsortedSegmentSum0DIndices1DData(self): for dtype in self.numeric_types: @@ -49,14 +63,14 @@ class SegmentReductionOpsTest(XLATestCase): [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 1, 2, 3, 4, 5], [0, 0, 0, 0, 0, 0]], dtype=dtype), - self.UnsortedSegmentSum( + self._unsortedSegmentSum( np.array([0, 1, 2, 3, 4, 5], dtype=dtype), 2, 4)) def testUnsortedSegmentSum1DIndices1DData(self): for dtype in self.numeric_types: self.assertAllClose( np.array([1, 3, 2, 9], dtype=dtype), - self.UnsortedSegmentSum( + self._unsortedSegmentSum( np.array([0, 1, 2, 3, 4, 5], dtype=dtype), np.array([3, 0, 2, 1, 3, 3], dtype=np.int32), 4)) @@ -64,7 +78,7 @@ class SegmentReductionOpsTest(XLATestCase): for dtype in self.numeric_types: self.assertAllClose( np.array([6, 3, 0, 6], dtype=dtype), - self.UnsortedSegmentSum( + self._unsortedSegmentSum( np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) @@ -76,7 +90,7 @@ class SegmentReductionOpsTest(XLATestCase): dtype=dtype) indices = np.array([8, 1, 0, 3, 7], dtype=np.int32) num_segments = 10 - y = self.UnsortedSegmentSum(data, indices, num_segments) + y = self._unsortedSegmentSum(data, indices, num_segments) self.assertAllClose( np.array( [[30, 31, 32, 33], [20, 21, 22, 23], [0, 0, 0, 0], @@ -92,7 +106,7 @@ class SegmentReductionOpsTest(XLATestCase): dtype=dtype) indices = np.array([0, 1, 2, 0, 1], dtype=np.int32) num_segments = 4 - y = self.UnsortedSegmentSum(data, indices, num_segments) + y = self._unsortedSegmentSum(data, indices, num_segments) self.assertAllClose( np.array( [[40, 42, 44, 46], [70, 72, 74, 76], [30, 31, 32, 33], @@ -102,30 +116,30 @@ class SegmentReductionOpsTest(XLATestCase): def testUnsortedSegmentSum2DIndices3DData(self): for dtype in self.numeric_types: data = np.array( - [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], - [[200, 201, 202], [210, 211, 212]], [[300, 301, 302], - [310, 311, 312]]], + [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], [[ + 200, 201, 202 + ], [210, 211, 212]], [[300, 301, 302], [310, 311, 312]]], dtype=dtype) indices = np.array([[3, 5], [3, 1], [5, 0], [6, 2]], dtype=np.int32) num_segments = 8 - y = self.UnsortedSegmentSum(data, indices, num_segments) + y = self._unsortedSegmentSum(data, indices, num_segments) self.assertAllClose( np.array( - [[210, 211, 212], [110, 111, 112], [310, 311, 312], - [100, 102, 104], [0, 0, 0.], [210, 212, 214], [300, 301, - 302], [0, 0, 0]], + [[210, 211, 212], [110, 111, 112], [310, 311, 312], [ + 100, 102, 104 + ], [0, 0, 0.], [210, 212, 214], [300, 301, 302], [0, 0, 0]], dtype=dtype), y) def testUnsortedSegmentSum1DIndices3DData(self): for dtype in self.numeric_types: data = np.array( - [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], - [[200, 201, 202], [210, 211, 212]], [[300, 301, 302], - [310, 311, 312]]], + [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], [[ + 200, 201, 202 + ], [210, 211, 212]], [[300, 301, 302], [310, 311, 312]]], dtype=dtype) indices = np.array([3, 0, 2, 5], dtype=np.int32) num_segments = 6 - y = self.UnsortedSegmentSum(data, indices, num_segments) + y = self._unsortedSegmentSum(data, indices, num_segments) self.assertAllClose( np.array( [[[100, 101, 102.], [110, 111, 112]], [[0, 0, 0], [0, 0, 0]], @@ -138,10 +152,40 @@ class SegmentReductionOpsTest(XLATestCase): data = np.ones((4, 8, 7), dtype=dtype) indices = np.ones((3, 2), dtype=np.int32) num_segments = 4 - self.assertRaises(ValueError, - functools.partial(self.UnsortedSegmentSum, data, - indices, num_segments)) + self.assertRaises( + ValueError, + functools.partial(self._segmentReduction, + math_ops.unsorted_segment_sum, data, indices, + num_segments)) + + def testUnsortedSegmentOps1DIndices1DDataNegativeIndices(self): + """Tests for min, max, and prod ops. + + These share most of their implementation with sum, so we only test basic + functionality. + """ + for dtype in self.numeric_types: + self.assertAllClose( + np.array([8, 3, 1, 0], dtype=dtype), + self._unsortedSegmentProd( + np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), + np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) + + for dtype in self.int_types | self.float_types: + minval = dtypes.as_dtype(dtype).min + maxval = dtypes.as_dtype(dtype).max + + self.assertAllClose( + np.array([2, 3, maxval, 0], dtype=dtype), + self._unsortedSegmentMin( + np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), + np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) + self.assertAllClose( + np.array([4, 3, minval, 6], dtype=dtype), + self._unsortedSegmentMax( + np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), + np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) -if __name__ == '__main__': +if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py index 305ca0c6b78d3ef985deb38816f9388e7983906b..6c4890565d2083a9493abc59bd563c4dd9fdb186 100644 --- a/tensorflow/compiler/tests/slice_ops_test.py +++ b/tensorflow/compiler/tests/slice_ops_test.py @@ -18,14 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -class SliceTest(XLATestCase): +class SliceTest(xla_test.XLATestCase): def test1D(self): for dtype in self.numeric_types: @@ -110,7 +110,7 @@ class SliceTest(XLATestCase): self.assertAllEqual([[[1, 1, 1, 1], [6, 5, 4, 3]]], result) -class StridedSliceTest(XLATestCase): +class StridedSliceTest(xla_test.XLATestCase): def test1D(self): for dtype in self.numeric_types: diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8ae579abda9854079ee491a7254eb4d09183594a --- /dev/null +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -0,0 +1,131 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for sorting operators.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.compiler.tf2xla.python import xla +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + + +class XlaSortOpTest(xla_test.XLATestCase): + + def _assertOpOutputMatchesExpected(self, op, args, expected): + with self.test_session() as session: + with self.test_scope(): + placeholders = [ + array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) + for arg in args + ] + feeds = {placeholders[i]: args[i] for i in range(0, len(args))} + output = op(*placeholders) + if isinstance(output, ops.Tensor): + output = [output] + + results = session.run(output, feeds) + for result, v in zip(results, expected): + self.assertAllClose(v, result, rtol=1e-3) + + def testSort(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32]) + for dtype in supported_types.intersection(self.numeric_types): + x = np.arange(101, dtype=dtype) + np.random.shuffle(x) + self._assertOpOutputMatchesExpected( + xla.sort, [x], expected=[np.arange(101, dtype=dtype)]) + + def testTopK(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + # Only bfloat16 is implemented. + bfloat16 = dtypes.bfloat16.as_numpy_dtype + if bfloat16 in self.numeric_types: + for x in [np.arange(20)]: + np.random.shuffle(x) + for k in [0, 1, 2, 10, 20]: + indices = x.argsort()[::-1][:k] + + def topk(v, k=k): + return nn_ops.top_k(v, k=k, sorted=True) + + self._assertOpOutputMatchesExpected( + topk, [x.astype(bfloat16)], + expected=[x[indices].astype(bfloat16), indices]) + + def testTopKZeros(self): + """Tests that positive and negative zeros sort correctly.""" + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + # Only bfloat16 is implemented. + bfloat16 = dtypes.bfloat16.as_numpy_dtype + if bfloat16 not in self.numeric_types: + return + + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.bfloat16) + with self.test_scope(): + topk = nn_ops.top_k(p, k=4) + results = sess.run( + topk, + {p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)}) + self.assertAllEqual( + np.array([3., 0., 0., 0.], dtype=bfloat16), results[0]) + self.assertEqual(list([3, 0, 1, 2]), list(results[1])) + + def testTopKInfinities(self): + """Tests that positive and negative infinity sort correctly.""" + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + # Only bfloat16 is implemented. + bfloat16 = dtypes.bfloat16.as_numpy_dtype + if bfloat16 not in self.numeric_types: + return + + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.bfloat16) + with self.test_scope(): + topk = nn_ops.top_k(p, k=6) + results = sess.run(topk, { + p: np.array( + [1, 2, float("inf"), -float("inf"), -1, -2], dtype=bfloat16) + }) + self.assertAllEqual( + np.array( + [float("inf"), 2.0, 1.0, -1.0, -2.0, -float("inf")], + dtype=bfloat16), results[0]) + self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1])) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py index f37c34156f96761632247be4bc1b62fca54f666e..c685bc548f9f6f8f7723c6f94dfd45f5420b4a67 100644 --- a/tensorflow/compiler/tests/spacetobatch_op_test.py +++ b/tensorflow/compiler/tests/spacetobatch_op_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops @@ -68,7 +68,7 @@ def space_to_batch_direct(input_array, block_shape, paddings): return permuted_reshaped_padded.reshape(output_shape) -class SpaceToBatchTest(XLATestCase): +class SpaceToBatchTest(xla_test.XLATestCase): """Tests input-output pairs for the SpaceToBatch and BatchToSpace ops.""" def _testPad(self, inputs, paddings, block_size, outputs): @@ -149,7 +149,7 @@ class SpaceToBatchTest(XLATestCase): self._testOne(x_np, block_size, x_out) -class SpaceToBatchNDTest(XLATestCase): +class SpaceToBatchNDTest(xla_test.XLATestCase): """Tests input-output pairs for the SpaceToBatchND and BatchToSpaceND ops.""" def _testPad(self, inputs, block_shape, paddings, outputs): diff --git a/tensorflow/compiler/tests/sparse_to_dense_op_test.py b/tensorflow/compiler/tests/sparse_to_dense_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3db8101c4bfbb1b53c7318a36519612984d6f179 --- /dev/null +++ b/tensorflow/compiler/tests/sparse_to_dense_op_test.py @@ -0,0 +1,118 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.kernels.sparse_op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +def _SparseToDense(sparse_indices, + output_size, + sparse_values, + default_value, + validate_indices=True): + feed_sparse_indices = array_ops.placeholder(dtypes.int32) + feed_dict = {feed_sparse_indices: sparse_indices} + return sparse_ops.sparse_to_dense( + feed_sparse_indices, + output_size, + sparse_values, + default_value=default_value, + validate_indices=validate_indices).eval(feed_dict=feed_dict) + + +class SparseToDenseTest(xla_test.XLATestCase): + + def testInt(self): + with self.test_session(), self.test_scope(): + tf_ans = _SparseToDense([1, 3], [5], 1, 0) + np_ans = np.array([0, 1, 0, 1, 0]).astype(np.int32) + self.assertAllClose(np_ans, tf_ans) + + def testFloat(self): + with self.test_session(), self.test_scope(): + tf_ans = _SparseToDense([1, 3], [5], 1.0, 0.0) + np_ans = np.array([0, 1, 0, 1, 0]).astype(np.float32) + self.assertAllClose(np_ans, tf_ans) + + def testSetValue(self): + with self.test_session(), self.test_scope(): + tf_ans = _SparseToDense([1, 3], [5], [1, 2], -1) + np_ans = np.array([-1, 1, -1, 2, -1]).astype(np.int32) + self.assertAllClose(np_ans, tf_ans) + + def testSetSingleValue(self): + with self.test_session(), self.test_scope(): + tf_ans = _SparseToDense([1, 3], [5], 1, -1) + np_ans = np.array([-1, 1, -1, 1, -1]).astype(np.int32) + self.assertAllClose(np_ans, tf_ans) + + def test2d(self): + # pylint: disable=bad-whitespace + with self.test_session(), self.test_scope(): + tf_ans = _SparseToDense([[1, 3], [2, 0]], [3, 4], 1, -1) + np_ans = np.array([[-1, -1, -1, -1], + [-1, -1, -1, 1], + [ 1, -1, -1, -1]]).astype(np.int32) + self.assertAllClose(np_ans, tf_ans) + + def testZeroDefault(self): + with self.test_session(): + x = sparse_ops.sparse_to_dense(2, [4], 7).eval() + self.assertAllEqual(x, [0, 0, 7, 0]) + + def test3d(self): + with self.test_session(), self.test_scope(): + tf_ans = _SparseToDense([[1, 3, 0], [2, 0, 1]], [3, 4, 2], 1, -1) + np_ans = np.ones((3, 4, 2), dtype=np.int32) * -1 + np_ans[1, 3, 0] = 1 + np_ans[2, 0, 1] = 1 + self.assertAllClose(np_ans, tf_ans) + + def testBadShape(self): + with self.test_session(), self.test_scope(): + with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"): + _SparseToDense([1, 3], [[5], [3]], 1, -1) + + def testBadValue(self): + with self.test_session(), self.test_scope(): + with self.assertRaisesOpError( + r"sparse_values has incorrect shape \[2,1\], " + r"should be \[\] or \[2\]"): + _SparseToDense([1, 3], [5], [[5], [3]], -1) + + def testBadNumValues(self): + with self.test_session(), self.test_scope(): + with self.assertRaisesOpError( + r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"): + _SparseToDense([1, 3], [5], [1, 2, 3], -1) + + def testBadDefault(self): + with self.test_session(), self.test_scope(): + with self.assertRaisesOpError("default_value should be a scalar"): + _SparseToDense([1, 3], [5], [1, 2], [0]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/stack_ops_test.py b/tensorflow/compiler/tests/stack_ops_test.py index 94342f9567ca71274609e63b0482d55637c98d51..b7dd787feff2b22a9cfb5d43a4ba6ceb6eb0b301 100644 --- a/tensorflow/compiler/tests/stack_ops_test.py +++ b/tensorflow/compiler/tests/stack_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -28,7 +28,7 @@ from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.platform import test -class StackOpTest(XLATestCase): +class StackOpTest(xla_test.XLATestCase): def testStackPushPop(self): with self.test_session(), self.test_scope(): diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index b6f8390a45d43bf7666b90e14cc6ff2f3f61947e..d162675ef840131485128414b4a29e3cd89c8761 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -22,14 +22,15 @@ import math import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.contrib import stateless from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions import special_math from tensorflow.python.platform import test -class StatelessRandomOpsTest(XLATestCase): +class StatelessRandomOpsTest(xla_test.XLATestCase): """Test cases for stateless random-number generator operators.""" def _random_types(self): @@ -122,6 +123,56 @@ class StatelessRandomOpsTest(XLATestCase): # so to avoid flakiness the seed is fixed. self.assertTrue(self._anderson_darling(y) < 2.492) + def testTruncatedNormalIsInRange(self): + # TODO(b/34339814): implement inverse erf support for non-F32 types. + for dtype in [dtypes.float32]: + with self.test_session() as sess, self.test_scope(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + n = 10000000 + x = stateless.stateless_truncated_normal( + shape=[n], seed=seed_t, dtype=dtype) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) + + def normal_cdf(x): + return .5 * math.erfc(-x / math.sqrt(2)) + + def normal_pdf(x): + return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi) + + def probit(x, sess=sess): + return sess.run(special_math.ndtri(x)) + + a = -2. + b = 2. + mu = 0. + sigma = 1. + + alpha = (a - mu) / sigma + beta = (b - mu) / sigma + z = normal_cdf(beta) - normal_cdf(alpha) + + self.assertTrue((y >= a).sum() == n) + self.assertTrue((y <= b).sum() == n) + + # For more information on these calculations, see: + # Burkardt, John. "The Truncated Normal Distribution". + # Department of Scientific Computing website. Florida State University. + expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma + actual_mean = np.mean(y) + self.assertAllClose(actual_mean, expected_mean, atol=2e-4) + + expected_median = mu + probit( + (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma + actual_median = np.median(y) + self.assertAllClose(actual_median, expected_median, atol=8e-4) + + expected_variance = sigma**2 * (1 + ( + (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - ( + (normal_pdf(alpha) - normal_pdf(beta)) / z)**2) + actual_variance = np.var(y) + self.assertAllClose(actual_variance, expected_variance, rtol=1e-3) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index ef047005b60bd156a677050368ef67ae030d6c3a..effa5a59fee7dda543b2c409dfaa27a972a55808 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_math_ops @@ -28,7 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class TernaryOpsTest(XLATestCase): +class TernaryOpsTest(xla_test.XLATestCase): def _testTernary(self, op, a, b, c, expected): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/test_utils.py b/tensorflow/compiler/tests/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6abde18ea91f16d153a154b94effab037a911c6c --- /dev/null +++ b/tensorflow/compiler/tests/test_utils.py @@ -0,0 +1,63 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for helping test ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + + +def ConvertBetweenDataFormats(x, data_format_src, data_format_dst): + """Converts 4D tensor between data formats.""" + + valid_data_formats = ["NHWC", "NCHW", "HWNC", "HWCN"] + if data_format_src not in valid_data_formats: + raise ValueError("data_format_src must be of %s, got %s." % + (valid_data_formats, data_format_src)) + if data_format_dst not in valid_data_formats: + raise ValueError("data_format_dst must be of %s, got %s." % + (valid_data_formats, data_format_dst)) + if len(x.shape) != 4: + raise ValueError("x must be 4D, got shape %s." % x.shape) + + if data_format_src == data_format_dst: + return x + + dim_map = {d: i for i, d in enumerate(data_format_src)} + transpose_dims = [dim_map[d] for d in data_format_dst] + return np.transpose(x, transpose_dims) + + +def PermuteDimsBetweenDataFormats(dims, data_format_src, data_format_dst): + """Get new shape for converting between data formats.""" + + valid_data_formats = ["NHWC", "NCHW", "HWNC", "HWCN"] + if data_format_src not in valid_data_formats: + raise ValueError("data_format_src must be of %s, got %s." % + (valid_data_formats, data_format_src)) + if data_format_dst not in valid_data_formats: + raise ValueError("data_format_dst must be of %s, got %s." % + (valid_data_formats, data_format_dst)) + if len(dims) != 4: + raise ValueError("dims must be of length 4, got %s." % dims) + + if data_format_src == data_format_dst: + return dims + + dim_map = {d: i for i, d in enumerate(data_format_src)} + permuted_dims = [dims[dim_map[d]] for d in data_format_dst] + return permuted_dims diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index e610b63e301c75f532db1b58cd26533effea174d..6a7011aea6cc3f942fecf27a640b998bfc10c0de 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -23,7 +23,7 @@ import unittest import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import bitwise_ops @@ -44,11 +44,16 @@ def nhwc_to_format(x, data_format): raise ValueError("Unknown format {}".format(data_format)) -class UnaryOpsTest(XLATestCase): +class UnaryOpsTest(xla_test.XLATestCase): """Test cases for unary operators.""" - def _assertOpOutputMatchesExpected(self, op, inp, expected, - equality_test=None, rtol=1e-3, atol=1e-5): + def _assertOpOutputMatchesExpected(self, + op, + inp, + expected, + equality_test=None, + rtol=1e-3, + atol=1e-5): """Verifies that 'op' produces 'expected' when fed input 'inp' . Args: @@ -81,10 +86,10 @@ class UnaryOpsTest(XLATestCase): def testAllTypeOps(self): for dtype in self.numeric_types: self._assertOpOutputMatchesExpected( - array_ops.diag, - np.array([1, 2, 3, 4], dtype=dtype), - np.array([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]], - dtype=dtype)) + array_ops.diag, np.array([1, 2, 3, 4], dtype=dtype), + np.array( + [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]], + dtype=dtype)) self._assertOpOutputMatchesExpected( array_ops.diag_part, np.arange(36).reshape([2, 3, 2, 3]).astype(dtype), @@ -102,8 +107,7 @@ class UnaryOpsTest(XLATestCase): expected=np.array([[-1, 1]], dtype=dtype)) self._assertOpOutputMatchesExpected( - array_ops.matrix_diag, - np.array([[1, 2], [3, 4]], dtype=dtype), + array_ops.matrix_diag, np.array([[1, 2], [3, 4]], dtype=dtype), np.array([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], dtype=dtype)) self._assertOpOutputMatchesExpected( array_ops.matrix_diag, np.array([1, 2, 3, 4], dtype=dtype), @@ -115,10 +119,10 @@ class UnaryOpsTest(XLATestCase): np.array( [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], dtype=dtype), np.array( - [[[[1, 0, 0], [0, 2, 0], [0, 0, 3]], - [[4, 0, 0], [0, 5, 0], [0, 0, 6]]], - [[[7, 0, 0], [0, 8, 0], [0, 0, 9]], - [[10, 0, 0], [0, 11, 0], [0, 0, 12]]]], + [[[[1, 0, 0], [0, 2, 0], [0, 0, 3]], [[4, 0, 0], [0, 5, 0], [ + 0, 0, 6 + ]]], [[[7, 0, 0], [0, 8, 0], [0, 0, 9]], [[10, 0, 0], [0, 11, 0], + [0, 0, 12]]]], dtype=dtype)) self._assertOpOutputMatchesExpected( array_ops.matrix_diag_part, @@ -159,36 +163,30 @@ class UnaryOpsTest(XLATestCase): continue x = np.arange(-0.90, 0.90, 0.25) self._assertOpOutputMatchesExpected( - math_ops.acos, - x.astype(dtype), - expected=np.arccos(x).astype(dtype)) + math_ops.acos, x.astype(dtype), expected=np.arccos(x).astype(dtype)) self._assertOpOutputMatchesExpected( - math_ops.asin, - x.astype(dtype), - expected=np.arcsin(x).astype(dtype)) + math_ops.asin, x.astype(dtype), expected=np.arcsin(x).astype(dtype)) x = np.arange(-3, 3).reshape(1, 3, 2) self._assertOpOutputMatchesExpected( - math_ops.atan, - x.astype(dtype), - expected=np.arctan(x).astype(dtype)) + math_ops.atan, x.astype(dtype), expected=np.arctan(x).astype(dtype)) self._assertOpOutputMatchesExpected( math_ops.acosh, np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array([0, 1.3169579, 1.76274717, 2.06343707], - dtype=dtype)) + expected=np.array( + [0, 1.3169579, 1.76274717, 2.06343707], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.asinh, np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array([0.88137359, 1.44363548, 1.81844646, 2.09471255], - dtype=dtype)) + expected=np.array( + [0.88137359, 1.44363548, 1.81844646, 2.09471255], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.atanh, np.array([0.1, 0.2, 0.3, 0.4], dtype=dtype), - expected=np.array([0.10033535, 0.20273255, 0.3095196, 0.42364893], - dtype=dtype)) + expected=np.array( + [0.10033535, 0.20273255, 0.3095196, 0.42364893], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.ceil, @@ -198,8 +196,8 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.cosh, np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array([1.54308063, 3.76219569, 10.067662, 27.30823284], - dtype=dtype)) + expected=np.array( + [1.54308063, 3.76219569, 10.067662, 27.30823284], dtype=dtype)) # Disable float16 testing for now if dtype != np.float16: @@ -229,8 +227,8 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.is_finite, - np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], - dtype=dtype), + np.array( + [[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype), expected=np.array([[0, 1, 1, 1, 1, 1, 1, 0, 0]], dtype=np.bool)) # Tests for tf.nn ops. @@ -271,16 +269,20 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.rint, - np.array([[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5], - [0.5, 1.5, 2.5, 3.5]], dtype=dtype), - expected=np.array([[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], - dtype=dtype)) + np.array( + [[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5], + [0.5, 1.5, 2.5, 3.5]], + dtype=dtype), + expected=np.array( + [[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.round, - np.array([[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5], - [0.5, 1.5, 2.5, 3.5]], dtype=dtype), - expected=np.array([[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], - dtype=dtype)) + np.array( + [[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5], + [0.5, 1.5, 2.5, 3.5]], + dtype=dtype), + expected=np.array( + [[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.rsqrt, @@ -289,10 +291,7 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.sigmoid, - np.array( - [[1, 1, 1, 1], - [1, 2, 3, 4]], - dtype=dtype), + np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), expected=np.array( [[0.7310586, 0.7310586, 0.7310586, 0.7310586], [0.7310586, 0.880797, 0.95257413, 0.98201376]], @@ -306,8 +305,8 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.sinh, np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array([1.17520119, 3.62686041, 10.01787493, 27.2899172], - dtype=dtype)) + expected=np.array( + [1.17520119, 3.62686041, 10.01787493, 27.2899172], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.sqrt, @@ -317,15 +316,12 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.tan, np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array([1.55740772, -2.18503986, -0.14254654, 1.15782128], - dtype=dtype)) + expected=np.array( + [1.55740772, -2.18503986, -0.14254654, 1.15782128], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.tanh, - np.array( - [[1, 1, 1, 1], - [1, 2, 3, 4]], - dtype=dtype), + np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), expected=np.array( [[0.76159418, 0.76159418, 0.76159418, 0.76159418], [0.76159418, 0.96402758, 0.99505478, 0.99932933]], @@ -333,10 +329,7 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( nn_ops.log_softmax, - np.array( - [[1, 1, 1, 1], - [1, 2, 3, 4]], - dtype=dtype), + np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), expected=np.array( [[-1.3862944, -1.3862944, -1.3862944, -1.3862944], [-3.4401896, -2.4401896, -1.4401897, -0.44018969]], @@ -370,10 +363,7 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( nn_ops.softmax, - np.array( - [[1, 1, 1, 1], - [1, 2, 3, 4]], - dtype=dtype), + np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), expected=np.array( [[0.25, 0.25, 0.25, 0.25], [0.032058604, 0.087144323, 0.23688284, 0.64391428]], @@ -382,8 +372,8 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( nn_ops.softsign, np.array([[-2, -1, 0, 1, 2]], dtype=dtype), - expected=np.array([[-0.66666669, -0.5, 0, 0.5, 0.66666669]], - dtype=dtype)) + expected=np.array( + [[-0.66666669, -0.5, 0, 0.5, 0.66666669]], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.is_finite, @@ -392,10 +382,23 @@ class UnaryOpsTest(XLATestCase): expected=np.array( [[True, False, True], [False, True, True]], dtype=np.bool)) + def quantize_and_dequantize_v2(x): + return array_ops.quantize_and_dequantize_v2( + x, -127, 127, signed_input=True, num_bits=8) + + self._assertOpOutputMatchesExpected( + quantize_and_dequantize_v2, + np.array([-1, -0.5, 0, 0.3], dtype=dtype), + expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype)) + + def quantize_and_dequantize_v3(x): + return array_ops.quantize_and_dequantize_v3( + x, -127, 127, num_bits=8, signed_input=True, range_given=False) + self._assertOpOutputMatchesExpected( - lambda x: array_ops.quantize_and_dequantize_v2(x, -127, 127, True, 8), + quantize_and_dequantize_v3, np.array([-1, -0.5, 0, 0.3], dtype=dtype), - expected=np.array([-1, -64.0 / 127, 0, 38.0 / 127], dtype=dtype)) + expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype)) def testComplexOps(self): for dtype in self.complex_types: @@ -576,13 +579,13 @@ class UnaryOpsTest(XLATestCase): for dtype in self.float_types: self._assertOpOutputMatchesExpected( math_ops.is_inf, - np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], - dtype=dtype), + np.array( + [[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype), expected=np.array([[1, 0, 0, 0, 0, 0, 0, 1, 0]], dtype=np.bool)) self._assertOpOutputMatchesExpected( math_ops.is_nan, - np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], - dtype=dtype), + np.array( + [[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype), expected=np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=np.bool)) def testLogicalOps(self): @@ -599,14 +602,15 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( lambda x: gen_nn_ops.bias_add_grad(x, data_format="NCHW"), - np.array([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], - dtype=np.float32), + np.array( + [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], dtype=np.float32), expected=np.array([10., 26.], dtype=np.float32)) def testCast(self): shapes = [[], [4], [2, 3], [2, 0, 4]] - types = (set([dtypes.bool, dtypes.int32, dtypes.float32]) | - self.complex_tf_types) + types = ( + set([dtypes.bool, dtypes.int32, dtypes.float32]) + | self.complex_tf_types) for shape in shapes: for src_type in types: for dst_type in types: @@ -648,14 +652,11 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( rank_op, dtype(7), expected=np.int32(0)) self._assertOpOutputMatchesExpected( - rank_op, np.array( - [[], []], dtype=dtype), expected=np.int32(2)) + rank_op, np.array([[], []], dtype=dtype), expected=np.int32(2)) self._assertOpOutputMatchesExpected( - rank_op, np.array( - [-1, 1], dtype=dtype), expected=np.int32(1)) + rank_op, np.array([-1, 1], dtype=dtype), expected=np.int32(1)) self._assertOpOutputMatchesExpected( - rank_op, np.array( - [[-1, 1]], dtype=dtype), expected=np.int32(2)) + rank_op, np.array([[-1, 1]], dtype=dtype), expected=np.int32(2)) self._assertOpOutputMatchesExpected( rank_op, np.array([[-1], [1], [4]], dtype=dtype), @@ -720,97 +721,97 @@ class UnaryOpsTest(XLATestCase): equality_test=self.ListsAreClose) def testDepthToSpace(self): + def make_op(data_format): + def op(x): - return array_ops.depth_to_space(x, block_size=2, - data_format=data_format) + return array_ops.depth_to_space( + x, block_size=2, data_format=data_format) + return op for dtype in self.numeric_types: for data_format in ["NCHW", "NHWC"]: self._assertOpOutputMatchesExpected( make_op(data_format), - nhwc_to_format(np.array([[[[1, 2, 3, 4]]]], dtype=dtype), - data_format), - expected=nhwc_to_format(np.array([[[[1], [2]], - [[3], [4]]]], dtype=dtype), - data_format)) + nhwc_to_format( + np.array([[[[1, 2, 3, 4]]]], dtype=dtype), data_format), + expected=nhwc_to_format( + np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype), data_format)) self._assertOpOutputMatchesExpected( make_op(data_format), nhwc_to_format( - np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], - dtype=dtype), + np.array( + [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], dtype=dtype), data_format), expected=nhwc_to_format( - np.array([[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]], - dtype=dtype), - data_format)) + np.array( + [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]], + dtype=dtype), data_format)) self._assertOpOutputMatchesExpected( make_op(data_format), nhwc_to_format( - np.array([[[[1, 2, 3, 4], - [5, 6, 7, 8]], - [[9, 10, 11, 12], - [13, 14, 15, 16]]]], dtype=dtype), - data_format), + np.array( + [[[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], + [13, 14, 15, 16]]]], + dtype=dtype), data_format), expected=nhwc_to_format( - np.array([[[[1], [2], [5], [6]], - [[3], [4], [7], [8]], - [[9], [10], [13], [14]], - [[11], [12], [15], [16]]]], dtype=dtype), - data_format)) + np.array( + [[[[1], [2], [5], [6]], [[3], [4], [7], [8]], + [[9], [10], [13], [14]], [[11], [12], [15], [16]]]], + dtype=dtype), data_format)) def testSpaceToDepth(self): + def make_op(data_format): + def op(x): - return array_ops.space_to_depth(x, block_size=2, - data_format=data_format) + return array_ops.space_to_depth( + x, block_size=2, data_format=data_format) + return op for dtype in self.numeric_types: for data_format in ["NCHW", "NHWC"]: self._assertOpOutputMatchesExpected( make_op(data_format), - nhwc_to_format(np.array([[[[1], [2]], - [[3], [4]]]], dtype=dtype), - data_format), - expected=nhwc_to_format(np.array([[[[1, 2, 3, 4]]]], dtype=dtype), - data_format)) + nhwc_to_format( + np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype), data_format), + expected=nhwc_to_format( + np.array([[[[1, 2, 3, 4]]]], dtype=dtype), data_format)) self._assertOpOutputMatchesExpected( make_op(data_format), - nhwc_to_format(np.array([[[[1, 2, 3], [4, 5, 6]], - [[7, 8, 9], [10, 11, 12]]]], dtype=dtype), - data_format), + nhwc_to_format( + np.array( + [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]], + dtype=dtype), data_format), expected=nhwc_to_format( - np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], - dtype=dtype), + np.array( + [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], dtype=dtype), data_format)) self._assertOpOutputMatchesExpected( make_op(data_format), - nhwc_to_format(np.array([[[[1], [2], [5], [6]], - [[3], [4], [7], [8]], - [[9], [10], [13], [14]], - [[11], [12], [15], [16]]]], dtype=dtype), - data_format), + nhwc_to_format( + np.array( + [[[[1], [2], [5], [6]], [[3], [4], [7], [8]], + [[9], [10], [13], [14]], [[11], [12], [15], [16]]]], + dtype=dtype), data_format), expected=nhwc_to_format( - np.array([[[[1, 2, 3, 4], - [5, 6, 7, 8]], - [[9, 10, 11, 12], - [13, 14, 15, 16]]]], dtype=dtype), - data_format)) + np.array( + [[[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], + [13, 14, 15, 16]]]], + dtype=dtype), data_format)) def _assertSoftplusMatchesExpected(self, features, dtype): features = np.array(features, dtype=dtype) zero = np.asarray(0).astype(dtype) expected = np.logaddexp(zero, features) self._assertOpOutputMatchesExpected( - nn_ops.softplus, features, expected=expected, - rtol=1e-6, - atol=9.1e-6) + nn_ops.softplus, features, expected=expected, rtol=1e-6, atol=9.1e-6) def testSoftplus(self): for dtype in self.float_types: @@ -824,9 +825,10 @@ class UnaryOpsTest(XLATestCase): one = dtype(1) ten = dtype(10) self._assertSoftplusMatchesExpected([ - log_eps, log_eps - one, log_eps + one, log_eps - ten, - log_eps + ten, -log_eps, -log_eps - one, -log_eps + one, - -log_eps - ten, -log_eps + ten], dtype) + log_eps, log_eps - one, log_eps + one, log_eps - ten, log_eps + ten, + -log_eps, -log_eps - one, -log_eps + one, -log_eps - ten, + -log_eps + ten + ], dtype) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index 2c09b03d5a35cde2c42d8a145781270c0c908587..dd2c252d383bca9c59033ac07e442b487e4975a6 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -20,12 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_state_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -36,7 +37,7 @@ from tensorflow.python.platform import googletest from tensorflow.python.training.gradient_descent import GradientDescentOptimizer -class VariableOpsTest(XLATestCase): +class VariableOpsTest(xla_test.XLATestCase): """Test cases for resource variable operators.""" def testOneWriteOneOutput(self): @@ -52,9 +53,7 @@ class VariableOpsTest(XLATestCase): with ops.control_dependencies([x]): y = v.read_value() self.assertAllClose( - np.array([[2, 1 + 2j], [4, 5]]).astype(dtype), sess.run(y, { - p: 1 - })) + np.array([[2, 1 + 2j], [4, 5]]).astype(dtype), sess.run(y, {p: 1})) def testSparseRead0DIndices(self): for dtype in self.numeric_types: @@ -103,9 +102,9 @@ class VariableOpsTest(XLATestCase): x = v.sparse_read([[2, 1], [3, 0]]) self.assertAllClose( np.array( - [[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]], - [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]], - ).astype(dtype), sess.run(x)) + [[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]] + ], [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]] + ],).astype(dtype), sess.run(x)) def testShape(self): for dtype in self.numeric_types: @@ -206,6 +205,206 @@ class VariableOpsTest(XLATestCase): self.assertAllClose(update, result[1]) self.assertAllClose(update, result[2]) + def testScatterAdd(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[2, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[1], [7]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_add( + handle, [0], constant_op.constant([[2]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertAllEqual(sess.run(read), [[3], [7]]) + + def testScatterSub(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[2, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[4], [1]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_sub( + handle, [1], constant_op.constant([[2]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertAllEqual(sess.run(read), [[4], [-1]]) + + def testScatterMul(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[1]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_mul( + handle, [0], constant_op.constant([[5]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[5]]) + + def testScatterDiv(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[6]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_div( + handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertAllEqual(sess.run(read), [[2]]) + + def testScatterMin(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[6]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_min( + handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[3]]) + + def testScatterMax(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[6]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_max( + handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[6]]) + + def testScatterUpdate(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[6]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_update( + handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[3]]) + + def testScatterAddScalar(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[1]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_add( + handle, [0], constant_op.constant(2, dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[3]]) + + def testScatterSubScalar(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[1]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_sub( + handle, [0], constant_op.constant(2, dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[-1]]) + + def testScatterMulScalar(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[1]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_mul( + handle, [0], constant_op.constant(5, dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[5]]) + + def testScatterDivScalar(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[6]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_div( + handle, [0], constant_op.constant(3, dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[2]]) + + def testScatterMinScalar(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[6]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_min( + handle, [0], constant_op.constant(3, dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[3]]) + + def testScatterMaxScalar(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[6]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_max( + handle, [0], constant_op.constant(3, dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[6]]) + + def testScatterNdAddOps(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.float32, shape=[8]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([1] * 8, dtype=dtypes.float32))) + indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32) + updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32) + expected = np.array([1, 12, 1, 11, 10, 1, 1, 13]) + sess.run(gen_state_ops.resource_scatter_nd_add(handle, indices, updates)) + read = resource_variable_ops.read_variable_op( + handle, dtype=dtypes.float32) + self.assertAllClose(expected, sess.run(read)) + + def testScatterNdUpdateAddOps(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.float32, shape=[8]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([1] * 8, dtype=dtypes.float32))) + indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32) + updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32) + expected = np.array([1, 11, 1, 10, 9, 1, 1, 12]) + sess.run( + gen_state_ops.resource_scatter_nd_update(handle, indices, updates)) + read = resource_variable_ops.read_variable_op( + handle, dtype=dtypes.float32) + self.assertAllClose(expected, sess.run(read)) + class StridedSliceAssignChecker(object): """Compares the results of a slice assignment using Tensorflow and numpy.""" @@ -236,12 +435,12 @@ class StridedSliceAssignChecker(object): self.test.assertAllEqual(val, valnp) -class SliceAssignTest(XLATestCase): +class SliceAssignTest(xla_test.XLATestCase): def testSliceAssign(self): for dtype in self.numeric_types: - checker = StridedSliceAssignChecker(self, [[1, 2, 3], [4, 5, 6]], - dtype=dtype) + checker = StridedSliceAssignChecker( + self, [[1, 2, 3], [4, 5, 6]], dtype=dtype) # No-op assignment checker[:] = [[10, 20, 30], [40, 50, 60]] # Checks trivial (1,1) shape tensor diff --git a/tensorflow/compiler/tests/while_test.py b/tensorflow/compiler/tests/while_test.py index f79eb27435cc954cebde4357c1d946a320f4ed75..b637cf31cfc303ebe84ce8307ef4ad8b0b5cd720 100644 --- a/tensorflow/compiler/tests/while_test.py +++ b/tensorflow/compiler/tests/while_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -29,7 +29,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class WhileTest(XLATestCase): +class WhileTest(xla_test.XLATestCase): def testSingletonLoopHandrolled(self): # Define a function for the loop body diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index f0b010fa67f2ffb3f81fd14d4d89585f716b4890..06d977b93c28792704b910c688af510bc650d2a4 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -20,14 +20,14 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.platform import test -class XlaDeviceTest(XLATestCase): +class XlaDeviceTest(xla_test.XLATestCase): def testCopies(self): """Tests that copies onto and off XLA devices work.""" diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index cd57452302fcbde37d79ce760a80615a76d7ad8c..aa9c0596d158386c6149fd5d6cfb2236324813ab 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -169,6 +169,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:core_cpu", @@ -462,3 +463,13 @@ cc_library( "//tensorflow/core:protos_all_cc", ], ) + +tf_cc_test( + name = "xla_op_registry_test", + srcs = ["xla_op_registry_test.cc"], + deps = [ + ":xla_compiler", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 1438f6b48c4913e60b0c0a9f5c3d67fe595cbfe8..6cc95149a16a59fce8486c5d103ad09e3e262765 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -166,6 +166,27 @@ StatusOr AddNode(const NodeDef& node_def, Graph* graph) { return inserted_node; } +// Check that the graph has no cycle containing the given node. +Status CheckNoCycleContains(const Node* node, const int num_nodes) { + std::vector ready; + ready.push_back(node); + std::vector visited(num_nodes); + while (!ready.empty()) { + const Node* current_node = ready.back(); + ready.pop_back(); + visited[current_node->id()] = true; + for (const Edge* out : current_node->out_edges()) { + if (out->dst() == node) { + return errors::Internal("Detect a cycle: Node \"", node->name(), "\"(", + node->def().op(), ") feeds into itself."); + } else if (!visited[out->dst()->id()]) { + ready.push_back(out->dst()); + } + } + } + return Status::OK(); +} + StatusOr BuildArgNode(Graph* graph, DataType type, int index) { NodeDef arg_def; NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp); @@ -1407,6 +1428,10 @@ StatusOr FunctionalizeCond::ConvertToXlaIf( TF_RETURN_IF_ERROR( AddInputEdges(cond_arg_nodes, switch_cluster.predicate_edge, if_node)); TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node)); + // Check that the if_node doesn't feed into itself. + TF_RETURN_WITH_CONTEXT_IF_ERROR( + CheckNoCycleContains(if_node, graph_->num_node_ids()), + "ConvertToXlaIf failed."); return if_node; } @@ -1439,7 +1464,9 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, // invariant. std::vector cf_info; std::vector unreachable_nodes; - TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes)); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes), + "FunctionalizeControlFlow failed"); if (!unreachable_nodes.empty()) { return errors::InvalidArgument( "The following nodes are unreachable from the source in the graph: ", @@ -1464,10 +1491,6 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, frame.parent = parent; frame.name = cf.frame_name; ++parent->num_children; - } else if (frame.parent != parent) { - return errors::InvalidArgument("Mismatched parent frames for ", - cf.frame->id(), ": ", parent->name, " vs ", - frame.parent->name); } if (IsEnter(node)) { @@ -1477,12 +1500,6 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, &arg.is_loop_invariant)); frame.args.push_back(arg); } else if (IsLoopCond(node)) { - if (frame.loop_cond) { - return errors::InvalidArgument( - "Loop ", cf.frame_name, - " has more than one LoopCond node: ", node->name(), " and ", - frame.loop_cond->name()); - } frame.loop_cond = node; } frame.nodes.insert(node); @@ -1514,6 +1531,16 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, worklist.push_back(frame->parent); } } + // There should be no cycle at this point, since while loops have been removed + // from graph. + // Check that the newly added XlaWhile nodes don't feed into themselves. + for (const Node* node : graph->op_nodes()) { + if (node->def().op() == "XlaWhile") { + TF_RETURN_WITH_CONTEXT_IF_ERROR( + CheckNoCycleContains(node, graph->num_node_ids()), + "FunctionalizeLoop failed."); + } + } // FunctionalizeControlFlow is invoked for every function, so the loops's // bodies and conditionals that were extracted into functions will be handled diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 14977a908ae2b0ff7e13b634c41b6d331b4b8a36..aae2f8ee5acd6249f8b6002d94c877f18064f936 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/validate.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/util/equal_graph_def.h" @@ -1012,5 +1013,60 @@ TEST(FunctionalizeControlFlow, Complex) { } } +TEST(FunctionalizeControlFlow, Cycle) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + // ----------------------------------------------------- + // | | + // | v + // less -> switch_1 --> add -> merge_1 -> identity -> switch_2 + // | ^ | + // | | v + // --------> one -------------------------> add_2 ---> merge_2 + { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); + auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); + auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), x, less); + auto two = + ops::Const(scope.WithOpName("cond/two") + .WithControlDependencies(switch_1.output_true), + 2); + auto mul = ops::Multiply(scope.WithOpName("cond/true/mul"), + switch_1.output_true, two); + auto one = + ops::Const(scope.WithOpName("cond/one") + .WithControlDependencies(switch_1.output_false), + 1); + auto add = ops::Add(scope.WithOpName("cond/false/add"), + switch_1.output_false, one); + + auto merge_1 = ops::Merge(scope.WithOpName("cond/Merge"), + std::initializer_list{add, mul}); + auto identity = + ops::Identity(scope.WithOpName("cond/Merge/identity"), merge_1.output); + auto switch_2 = + ops::Switch(scope.WithOpName("grad/cond/Switch"), identity, less); + auto add_2 = ops::Add(scope.WithOpName("cond_2/false/add"), + switch_2.output_false, one); + auto mul_2 = ops::Multiply(scope.WithOpName("cond_2/true/mul"), + switch_2.output_true, two); + auto merge_2 = ops::Merge(scope.WithOpName("cond_2/Merge"), + std::initializer_list{add_2, mul_2}); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + } + // No cycle before functionalize control flow. + TF_EXPECT_OK(graph::ValidateGraphHasNoCycle(*graph)); + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + // switch_1 and switch_2 have the same switch depth. They are replaced by a + // single XlaIf node during FunctionalizeControlFlow, resulting in a cycle: + // less -> XlaIf <--> identity. + Status status = FunctionalizeControlFlow(graph.get(), &library); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE(str_util::StrContains(status.error_message(), "Detect a cycle")) + << status.error_message(); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 212f6f3966149ca0b2d2e012b19300e1f488f996..4900af6df17f360630abb1e64b7f144ccd4a0289 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" @@ -39,6 +40,7 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/validate.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" @@ -87,6 +89,8 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, } } // namespace Status GraphCompiler::Compile() { + // Check that the graph has no illegal cycles. + TF_RETURN_IF_ERROR(graph::ValidateGraphHasNoCycle(*graph_)); // Maintain a mapping from node id to node outputs. using NodeOutputs = std::vector; std::vector output_registry(graph_->num_node_ids()); @@ -227,7 +231,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, XlaContext& context = XlaContext::Get(op_context); auto* b = context.builder(); - auto output_handle = b->Call(*result.computation, handles); + auto output_handle = xla::Call(b, *result.computation, handles); // The output handle of `Call` computation is a tuple type. Unzip it so // that it can fit into future computations. int computation_output = 0; @@ -236,7 +240,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value); } else { xla_op_context.SetOutput( - i, b->GetTupleElement(output_handle, computation_output)); + i, xla::GetTupleElement(output_handle, computation_output)); ++computation_output; } } diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index edd2ab6301ee891c433639ce300cde0c72929cea..e6cbf2349d757179584343359e395c4f67f73c00 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -79,14 +79,17 @@ tf_kernel_library( "shape_util.cc", "slice_op.cc", "softmax_op.cc", + "sort_ops.cc", "spacetobatch_op.cc", "spacetodepth_op.cc", + "sparse_to_dense_op.cc", "split_op.cc", "stack_ops.cc", "stateless_random_ops.cc", "strided_slice_op.cc", "tensor_array_ops.cc", "tile_ops.cc", + "topk_op.cc", "training_ops.cc", "transpose_op.cc", "unary_ops.cc", @@ -104,6 +107,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:batch_dot", "//tensorflow/compiler/tf2xla/lib:cholesky", + "//tensorflow/compiler/tf2xla/lib:random", "//tensorflow/compiler/tf2xla/lib:scatter", "//tensorflow/compiler/tf2xla/lib:triangular_solve", "//tensorflow/compiler/tf2xla/lib:util", @@ -117,6 +121,7 @@ tf_kernel_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:image_ops_op_lib", diff --git a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc index 1e59868621475cf72f4cc8b14dafec2dd8cd5c95..e33532828040123243f839ab1aa655b4bbc72520 100644 --- a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { namespace { @@ -31,7 +32,7 @@ class AddNOp : public XlaOpKernel { xla::XlaOp sum = ctx->Input(0); for (int i = 1; i < ctx->num_inputs(); ++i) { - sum = ctx->builder()->Add(sum, ctx->Input(i)); + sum = xla::Add(sum, ctx->Input(i)); } ctx->SetOutput(0, sum); diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index b0ba25b9983c3a9af26728ce4b1c263c844327db..4cfe946b2e6146f034867c06e996ffae42b90705 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -28,11 +28,10 @@ class BatchMatMulOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto result = BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1), + auto result = BatchDot(ctx->Input(0), ctx->Input(1), /*transpose_x=*/adj_x_, /*transpose_y=*/adj_y_, /*conjugate_x=*/adj_x_, /*conjugate_y=*/adj_y_); - OP_REQUIRES_OK(ctx, result.status()); - ctx->SetOutput(0, result.ValueOrDie()); + ctx->SetOutput(0, result); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index 15e1815a4cf07ff50dd1431b6790d14781da590f..c4af79281d2162b1dbfb0a7881720892f4bc49d2 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { @@ -34,10 +35,11 @@ class FusedBatchNormOp : public XlaOpKernel { ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format: ", data_format_str)); OP_REQUIRES(ctx, - (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW), + (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW || + data_format_ == FORMAT_HWNC || data_format_ == FORMAT_HWCN), errors::InvalidArgument( "Unsupported data format ", ToString(data_format_), - "; supported formats are NHWC and NCHW")); + "; supported formats are NHWC, NCHW, HWNC and HWCN")); } void Compile(XlaOpKernelContext* ctx) override { @@ -48,8 +50,6 @@ class FusedBatchNormOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1), &scale_type)); - xla::XlaBuilder* builder = ctx->builder(); - xla::XlaOp input = ctx->Input(0); TensorShape input_shape = ctx->InputShape(0); @@ -59,30 +59,30 @@ class FusedBatchNormOp : public XlaOpKernel { // TODO(b/69928690): support mixed precision in the XLA batch normalization // operators. As a workaround, cast everything to the statistics type (which // may be more precise than the input type). - input = builder->ConvertElementType(input, scale_type); + input = xla::ConvertElementType(input, scale_type); if (is_training_) { - xla::XlaOp output = builder->BatchNormTraining( + xla::XlaOp output = xla::BatchNormTraining( input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index); // In training mode, outputs the normalized value as well as the // calculated mean and variance. - ctx->SetOutput(0, builder->ConvertElementType( - builder->GetTupleElement(output, 0), input_type)); - ctx->SetOutput(1, builder->GetTupleElement(output, 1)); - ctx->SetOutput(2, builder->GetTupleElement(output, 2)); + ctx->SetOutput(0, xla::ConvertElementType(xla::GetTupleElement(output, 0), + input_type)); + ctx->SetOutput(1, xla::GetTupleElement(output, 1)); + ctx->SetOutput(2, xla::GetTupleElement(output, 2)); // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved // space 1 & 2". They are used to pass the per-batch mean and // variance to the gradient. Here we maintain the same behavior by setting // them to the mean and variance calculated by BatchNormTraining. - ctx->SetOutput(3, builder->GetTupleElement(output, 1)); - ctx->SetOutput(4, builder->GetTupleElement(output, 2)); + ctx->SetOutput(3, xla::GetTupleElement(output, 1)); + ctx->SetOutput(4, xla::GetTupleElement(output, 2)); } else { - xla::XlaOp output = builder->BatchNormInference( + xla::XlaOp output = xla::BatchNormInference( input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4), epsilon_, feature_index); - ctx->SetOutput(0, builder->ConvertElementType(output, input_type)); + ctx->SetOutput(0, xla::ConvertElementType(output, input_type)); // Directly send input to output as mean and variance in inference mode. ctx->SetOutput(1, ctx->Input(3)); ctx->SetOutput(2, ctx->Input(4)); @@ -111,10 +111,11 @@ class FusedBatchNormGradOp : public XlaOpKernel { ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format: ", data_format_str)); OP_REQUIRES(ctx, - (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW), + (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW || + data_format_ == FORMAT_HWNC || data_format_ == FORMAT_HWCN), errors::InvalidArgument( "Unsupported data format ", ToString(data_format_), - "; supported formats are NHWC and NCHW")); + "; supported formats are NHWC, NCHW, HWNC and HWCN")); } void Compile(XlaOpKernelContext* ctx) override { @@ -142,12 +143,12 @@ class FusedBatchNormGradOp : public XlaOpKernel { xla::XlaOp offset_backprop; if (is_training_) { xla::XlaOp output = - b->BatchNormGrad(activations, scale, mean, var, grad_backprop, - epsilon_, feature_index); + xla::BatchNormGrad(activations, scale, mean, var, grad_backprop, + epsilon_, feature_index); - x_backprop = b->GetTupleElement(output, 0); - scale_backprop = b->GetTupleElement(output, 1); - offset_backprop = b->GetTupleElement(output, 2); + x_backprop = xla::GetTupleElement(output, 0); + scale_backprop = xla::GetTupleElement(output, 1); + offset_backprop = xla::GetTupleElement(output, 2); } else { // Reduce over all dimensions except the feature dim. std::vector reduction_dims(input_dims - 1); @@ -164,35 +165,35 @@ class FusedBatchNormGradOp : public XlaOpKernel { auto converted = XlaHelpers::ConvertElementType(b, grad_backprop, accumulation_type); auto reduce = - b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); + xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); offset_backprop = XlaHelpers::ConvertElementType(b, reduce, scale_dtype); // scratch1 = rsqrt(pop_var + epsilon) auto neg_half = XlaHelpers::FloatLiteral(b, scale_dtype, -0.5); - auto scratch1 = - b->Pow(b->Add(var, b->ConstantR0(epsilon_)), neg_half); + auto scratch1 = xla::Pow( + xla::Add(var, xla::ConstantR0(b, epsilon_)), neg_half); // scratch2 = sum(y_backprop * (x - mean)) auto mul = - b->Mul(grad_backprop, b->Sub(activations, mean, {feature_index})); + xla::Mul(grad_backprop, xla::Sub(activations, mean, {feature_index})); converted = XlaHelpers::ConvertElementType(b, mul, accumulation_type); reduce = - b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); + xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); auto scratch2 = XlaHelpers::ConvertElementType(b, reduce, scale_dtype); x_backprop = - b->Mul(grad_backprop, b->Mul(scratch1, scale), {feature_index}); - scale_backprop = b->Mul(scratch1, scratch2); + xla::Mul(grad_backprop, xla::Mul(scratch1, scale), {feature_index}); + scale_backprop = xla::Mul(scratch1, scratch2); } ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, x_backprop, input_dtype)); ctx->SetOutput(1, scale_backprop); ctx->SetOutput(2, offset_backprop); - ctx->SetConstantOutput(3, Tensor(scale_dtype, {})); - ctx->SetConstantOutput(4, Tensor(scale_dtype, {})); + ctx->SetConstantOutput(3, Tensor()); + ctx->SetConstantOutput(4, Tensor()); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 642278ab994bf3cc84396f093ed56b009a1435c1..26130fd9e7fce75c6d2a5a53cfc85842cf762b35 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { namespace { @@ -45,7 +46,6 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, ", 2] instead of ", xla::ShapeUtil::HumanString(crops.shape()))); - xla::XlaBuilder* b = ctx->builder(); const int64 batch_size = input_shape[0]; // Compute the product of the block_shape values. @@ -72,7 +72,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, reshaped_shape[block_rank] = batch_size / block_num_elems; std::copy(input_shape.begin() + 1, input_shape.end(), reshaped_shape.begin() + block_rank + 1); - xla::XlaOp reshaped = b->Reshape(input, reshaped_shape); + xla::XlaOp reshaped = xla::Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce `permuted` of shape // [batch / prod(block_shape), @@ -90,7 +90,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, } std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(), 1 + block_rank * 2); - xla::XlaOp permuted = b->Transpose(reshaped, permutation); + xla::XlaOp permuted = xla::Transpose(reshaped, permutation); // 3. Reshape `permuted` to produce `reshaped_permuted` of shape // [batch / prod(block_shape), @@ -110,7 +110,8 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, std::copy(remainder_shape.begin(), remainder_shape.end(), reshaped_permuted_shape.begin() + 1 + block_rank); - xla::XlaOp reshaped_permuted = b->Reshape(permuted, reshaped_permuted_shape); + xla::XlaOp reshaped_permuted = + xla::Reshape(permuted, reshaped_permuted_shape); // 4. Crop the start and end of dimensions `[1, ..., M]` of // `reshaped_permuted` according to `crops` to produce the output of shape: @@ -138,7 +139,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, " end: ", crop_end, " size ", reshaped_permuted_shape[1 + i])); } xla::XlaOp output = - b->Slice(reshaped_permuted, start_indices, end_indices, strides); + xla::Slice(reshaped_permuted, start_indices, end_indices, strides); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index 9d677f426650ea17a49e5ab1401078f04623fe97..e9b2c0b16d39cb3b747c0316621fb01de709b12e 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/util/tensor_format.h" @@ -60,8 +61,7 @@ class BiasOp : public XlaOpKernel { "of the input tensor: ", bias_shape.DebugString(), " vs. ", input_shape.DebugString())); - xla::XlaOp result = - ctx->builder()->Add(ctx->Input(0), ctx->Input(1), {feature_dim}); + xla::XlaOp result = xla::Add(ctx->Input(0), ctx->Input(1), {feature_dim}); ctx->SetOutput(0, result); } @@ -109,8 +109,8 @@ class BiasAddGradOp : public XlaOpKernel { auto converted = XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type); auto reduce = - b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), reduce_dims); + xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduce_dims); ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, reduce, input_type(0))); } diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index f04cde878e98002d9442e0f3ec251c5197ef7969..d6d4ae89376b67c14af8ef4f3a608fcc83b6fb59 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -41,18 +41,19 @@ namespace { const BCast& broadcast_helper, \ const std::vector& extend_dimensions) override { \ xla::XlaBuilder* b = ctx->builder(); \ + (void)b; \ return HLO; \ } \ }; \ REGISTER_XLA_OP(Name(#NAME), NAME##Op) -XLA_MAKE_BINARY(Add, b->Add(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Sub, b->Sub(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Mul, b->Mul(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Div, b->Div(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Add, xla::Add(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Sub, xla::Sub(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Mul, xla::Mul(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Atan2, b->Atan2(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Complex, b->Complex(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions)); // Implementation of FloorDiv. Pseudo-code: // if ((x < 0) != (y < 0)) { @@ -67,13 +68,13 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); auto one = XlaHelpers::One(b, dtype); - auto different_sign = b->Ne(b->Lt(x, zero), b->Lt(y, zero)); - auto abs_x = b->Abs(x); - auto abs_y = b->Abs(y); - auto t = b->Neg(b->Sub(b->Add(abs_x, abs_y), one)); - auto result = b->Select(different_sign, b->Div(t, abs_y), b->Div(x, y)); + auto different_sign = xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero)); + auto abs_x = xla::Abs(x); + auto abs_y = xla::Abs(y); + auto t = xla::Neg(xla::Sub(xla::Add(abs_x, abs_y), one)); + auto result = xla::Select(different_sign, xla::Div(t, abs_y), xla::Div(x, y)); if (DataTypeIsFloating(dtype)) { - result = b->Floor(result); + result = xla::Floor(result); } return result; } @@ -87,75 +88,78 @@ static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); - auto same_sign = b->Eq(b->Lt(x, zero), b->Lt(y, zero)); - auto trunc_mod = b->Rem(x, y); - return b->Select(same_sign, trunc_mod, b->Rem(b->Add(trunc_mod, y), y)); + auto same_sign = xla::Eq(xla::Lt(x, zero), xla::Lt(y, zero)); + auto trunc_mod = xla::Rem(x, y); + return xla::Select(same_sign, trunc_mod, xla::Rem(xla::Add(trunc_mod, y), y)); } XLA_MAKE_BINARY(FloorMod, FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper)); -XLA_MAKE_BINARY(BitwiseAnd, b->And(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(BitwiseOr, b->Or(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(BitwiseAnd, xla::And(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(BitwiseOr, xla::Or(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(BitwiseXor, xla::Xor(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(LeftShift, b->ShiftLeft(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(LeftShift, xla::ShiftLeft(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(RightShift, (DataTypeIsUnsigned(ctx->input_type(0)) - ? b->ShiftRightLogical(lhs, rhs, extend_dimensions) - : b->ShiftRightArithmetic(lhs, rhs, extend_dimensions))); - -XLA_MAKE_BINARY(LogicalAnd, b->And(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(LogicalOr, b->Or(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Mod, b->Rem(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Maximum, b->Max(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Minimum, b->Min(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(RealDiv, b->Div(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(ReciprocalGrad, b->Neg(b->Mul(rhs, b->Mul(lhs, lhs)))); + ? xla::ShiftRightLogical(lhs, rhs, extend_dimensions) + : xla::ShiftRightArithmetic(lhs, rhs, extend_dimensions))); + +XLA_MAKE_BINARY(LogicalAnd, xla::And(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(LogicalOr, xla::Or(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Mod, xla::Rem(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Maximum, xla::Max(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Minimum, xla::Min(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(RealDiv, xla::Div(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(ReciprocalGrad, xla::Neg(xla::Mul(rhs, xla::Mul(lhs, lhs)))); XLA_MAKE_BINARY( RsqrtGrad, - b->Mul(b->Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)), - b->Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)), - extend_dimensions)); -XLA_MAKE_BINARY(SqrtGrad, - b->Div(b->Mul(rhs, - XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), - lhs, extend_dimensions)); + xla::Mul(xla::Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)), + xla::Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)), + extend_dimensions)); +XLA_MAKE_BINARY( + SqrtGrad, + xla::Div(xla::Mul(rhs, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), + lhs, extend_dimensions)); static xla::XlaOp Square(xla::XlaBuilder* builder, const xla::XlaOp& x) { - return builder->Mul(x, x); + return xla::Mul(x, x); } XLA_MAKE_BINARY(SquaredDifference, - Square(b, b->Sub(lhs, rhs, extend_dimensions))); + Square(b, xla::Sub(lhs, rhs, extend_dimensions))); -XLA_MAKE_BINARY(TruncateDiv, b->Div(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(TruncateMod, b->Rem(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(TruncateDiv, xla::Div(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(TruncateMod, xla::Rem(lhs, rhs, extend_dimensions)); // Comparison ops -XLA_MAKE_BINARY(Equal, b->Eq(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(NotEqual, b->Ne(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Greater, b->Gt(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(GreaterEqual, b->Ge(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Less, b->Lt(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(LessEqual, b->Le(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Equal, xla::Eq(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(NotEqual, xla::Ne(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Greater, xla::Gt(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(GreaterEqual, xla::Ge(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Less, xla::Lt(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(LessEqual, xla::Le(lhs, rhs, extend_dimensions)); // Non-linear ops XLA_MAKE_BINARY(SigmoidGrad, - b->Mul(b->Mul(rhs, lhs), - b->Sub(XlaHelpers::One(b, input_type(0)), lhs))); + xla::Mul(xla::Mul(rhs, lhs), + xla::Sub(XlaHelpers::One(b, input_type(0)), lhs))); XLA_MAKE_BINARY(SoftplusGrad, - b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)), - XlaHelpers::One(b, input_type(1))))); + xla::Div(lhs, xla::Add(xla::Exp(xla::Neg(rhs)), + XlaHelpers::One(b, input_type(1))))); // softsigngrad(gradients, features) = gradients / (1 + abs(features)) ** 2 XLA_MAKE_BINARY(SoftsignGrad, - b->Div(lhs, Square(b, b->Add(XlaHelpers::One(b, input_type(0)), - b->Abs(rhs))))); + xla::Div(lhs, + Square(b, xla::Add(XlaHelpers::One(b, input_type(0)), + xla::Abs(rhs))))); -XLA_MAKE_BINARY(TanhGrad, b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)), - b->Mul(lhs, lhs)))); +XLA_MAKE_BINARY(TanhGrad, + xla::Mul(rhs, xla::Sub(XlaHelpers::One(b, input_type(0)), + xla::Mul(lhs, lhs)))); -XLA_MAKE_BINARY(Pow, b->Pow(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Pow, xla::Pow(lhs, rhs, extend_dimensions)); #undef XLA_MAKE_BINARY @@ -168,12 +172,13 @@ class ApproximateEqualOp : public XlaOpKernel { // Computes the max of the scalar input x and 0. void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* b = ctx->builder(); - auto abs = b->Abs(b->Sub(ctx->Input(0), ctx->Input(1))); + auto abs = xla::Abs(xla::Sub(ctx->Input(0), ctx->Input(1))); auto abs_shape = b->GetShape(abs); OP_REQUIRES_OK(ctx, abs_shape.status()); auto abs_type = abs_shape.ValueOrDie().element_type(); - auto result = b->Lt( - abs, b->ConvertElementType(b->ConstantR0(tolerance_), abs_type)); + auto result = + xla::Lt(abs, xla::ConvertElementType( + xla::ConstantR0(b, tolerance_), abs_type)); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc index ca9a6b40688d1e8496d1b823e20d273d519f65e8..efbdb76eaaf78904fe783a018940b1b096ec39bd 100644 --- a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { @@ -36,22 +37,22 @@ class BucketizeOp : public XlaOpKernel { const DataType dtype = context->input_type(0); xla::XlaOp input = context->Input(0); - xla::XlaOp boundaries = builder->ConstantR1(boundaries_); + xla::XlaOp boundaries = xla::ConstantR1(builder, boundaries_); // TODO(phawkins): the following behavior matches the behavior of the core // Bucketize kernel. However, comparing an int32 or int64 against float may // lead to inaccurate bucketing due to rounding. if (dtype == DT_DOUBLE) { - input = builder->ConvertElementType(input, xla::F64); - boundaries = builder->ConvertElementType(boundaries, xla::F64); + input = xla::ConvertElementType(input, xla::F64); + boundaries = xla::ConvertElementType(boundaries, xla::F64); } else { - input = builder->ConvertElementType(input, xla::F32); + input = xla::ConvertElementType(input, xla::F32); } - xla::XlaOp comparison = builder->ConvertElementType( - builder->Ge(builder->Broadcast(input, {1}), boundaries, - /*broadcast_dimensions=*/{0}), - xla::S32); - xla::XlaOp buckets = builder->Reduce( - comparison, /*init_value=*/builder->ConstantR0(0), + xla::XlaOp comparison = + xla::ConvertElementType(xla::Ge(xla::Broadcast(input, {1}), boundaries, + /*broadcast_dimensions=*/{0}), + xla::S32); + xla::XlaOp buckets = xla::Reduce( + comparison, /*init_value=*/xla::ConstantR0(builder, 0), /*computation=*/xla::CreateScalarAddComputation(xla::S32, builder), /*dimensions_to_reduce=*/{0}); context->SetOutput(0, buckets); diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index e9d98c768572c52825fa5192ecec834889f040fe..62eebf762b3e063da8ec456cc4726d3cc9b77d1d 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -40,14 +41,14 @@ class CastOp : public XlaOpKernel { if (src_dtype_ == dst_dtype_) { output = input; } else if (dst_dtype_ == DT_BOOL) { - output = builder->Ne(input, XlaHelpers::Zero(builder, src_dtype_)); + output = xla::Ne(input, XlaHelpers::Zero(builder, src_dtype_)); } else if (xla::primitive_util::IsComplexType(src_type_) && !xla::primitive_util::IsComplexType(dst_type_)) { // As in cast_op.h, we replicate the numpy behavior of truncating the // imaginary part. - output = builder->ConvertElementType(builder->Real(input), dst_type_); + output = xla::ConvertElementType(xla::Real(input), dst_type_); } else { - output = builder->ConvertElementType(input, dst_type_); + output = xla::ConvertElementType(input, dst_type_); } ctx->SetOutput(0, output); @@ -72,7 +73,6 @@ class BitcastOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* builder = ctx->builder(); xla::XlaOp input = ctx->Input(0); xla::XlaOp output; @@ -92,7 +92,7 @@ class BitcastOp : public XlaOpKernel { xla::primitive_util::BitWidth(dst_type_), errors::Unimplemented( "Only bitcasts between equally sized types supported.")); - output = builder->BitcastConvertType(input, dst_type_); + output = xla::BitcastConvertType(input, dst_type_); } ctx->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index 835a7f568945f0bee86fe2b39491c3326726e1aa..c137d026bda7d9263d6bec85b13d5ce1dc040038 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -65,17 +66,17 @@ class CategoricalOp : public XlaOpKernel { DataTypeToPrimitiveType(input_type(0), &uniform_xla_type)); xla::Shape uniform_shape = xla::ShapeUtil::MakeShape(uniform_xla_type, uniform_shape_array); - auto uniforms = builder->RngUniform( - XlaHelpers::Zero(builder, input_type(0)), - XlaHelpers::One(builder, input_type(0)), uniform_shape); + auto uniforms = + xla::RngUniform(XlaHelpers::Zero(builder, input_type(0)), + XlaHelpers::One(builder, input_type(0)), uniform_shape); // Use Gumbel softmax trick to generate categorical samples. // See: // https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/ // TODO(b/68769470): Switch to using a cumulative sum approach. auto softmax_entries = - builder->Sub(logits, builder->Log(builder->Neg(builder->Log(uniforms))), - /*broadcast_dimensions=*/{0, 2}); + xla::Sub(logits, xla::Log(xla::Neg(xla::Log(uniforms))), + /*broadcast_dimensions=*/{0, 2}); TensorShape softmax_shape(uniform_shape_array); xla::XlaOp argmax; diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc index fe6651793dc763d13f4a4b0ac294ec3ecf64af8f..9fcbc86adc0967cbb7fb73da8bdabc58b60953da 100644 --- a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc @@ -24,12 +24,7 @@ class CholeskyOp : public XlaOpKernel { public: explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - auto result = Cholesky(ctx->builder(), ctx->Input(0)); - if (!result.ok()) { - ctx->SetStatus(result.status()); - return; - } - ctx->SetOutput(0, result.ValueOrDie()); + ctx->SetOutput(0, Cholesky(ctx->Input(0))); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc index a00bc912f9f40052565446c6bf9390629af9a4cd..4e6d33304c4ae08a0fd1e0a8373267a527087528 100644 --- a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -29,7 +30,6 @@ class ClipByValueOp : public XlaOpKernel { const TensorShape min_shape = ctx->InputShape(1); const TensorShape max_shape = ctx->InputShape(2); - xla::XlaBuilder* builder = ctx->builder(); auto input = ctx->Input(0); auto min = ctx->Input(1); auto max = ctx->Input(2); @@ -45,13 +45,13 @@ class ClipByValueOp : public XlaOpKernel { if (shape != min_shape) { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(min_shape), shape_error()); - min = builder->Broadcast(min, shape.dim_sizes()); + min = xla::Broadcast(min, shape.dim_sizes()); } if (shape != max_shape) { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(max_shape), shape_error()); - max = builder->Broadcast(max, shape.dim_sizes()); + max = xla::Broadcast(max, shape.dim_sizes()); } - ctx->SetOutput(0, builder->Clamp(min, input, max)); + ctx->SetOutput(0, xla::Clamp(min, input, max)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index 78285affa1c399ae107a9172fb85cf257457c368..e3a32a5c0e2f93237c8c7ebeea3668b5d1ab6c23 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -88,7 +89,7 @@ class ConcatBaseOp : public XlaOpKernel { "] = ", in_shape.DebugString())); if (in_shape.dims() == 0) { // Inputs that come in as scalars must be reshaped to 1-vectors. - input_data.push_back(ctx->builder()->Reshape(handle, {1})); + input_data.push_back(xla::Reshape(handle, {1})); } else { input_data.push_back(handle); } @@ -96,7 +97,7 @@ class ConcatBaseOp : public XlaOpKernel { } VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis; - ctx->SetOutput(0, ctx->builder()->ConcatInDim(input_data, axis)); + ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index 59d06c654de18c9003fe0bdc706d0c2443de6d7b..f4360d8c3f6fc4007c31fdcfd7f7634de15c76d4 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/tensor.pb.h" @@ -53,41 +54,41 @@ class ConstOp : public XlaOpKernel { switch (proto_.dtype()) { case DT_BOOL: if (proto_.bool_val_size() == 1) { - ctx->SetOutput(0, - b->Broadcast(b->ConstantR0(proto_.bool_val(0)), - shape.dim_sizes())); + ctx->SetOutput( + 0, xla::Broadcast(xla::ConstantR0(b, proto_.bool_val(0)), + shape.dim_sizes())); return; } break; case DT_FLOAT: if (proto_.float_val_size() == 1) { - ctx->SetOutput( - 0, b->Broadcast(b->ConstantR0(proto_.float_val(0)), - shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(xla::ConstantR0( + b, proto_.float_val(0)), + shape.dim_sizes())); return; } break; case DT_DOUBLE: if (proto_.double_val_size() == 1) { - ctx->SetOutput( - 0, b->Broadcast(b->ConstantR0(proto_.double_val(0)), - shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(xla::ConstantR0( + b, proto_.double_val(0)), + shape.dim_sizes())); return; } break; case DT_INT32: if (proto_.int_val_size() == 1) { - ctx->SetOutput(0, - b->Broadcast(b->ConstantR0(proto_.int_val(0)), - shape.dim_sizes())); + ctx->SetOutput( + 0, xla::Broadcast(xla::ConstantR0(b, proto_.int_val(0)), + shape.dim_sizes())); return; } break; case DT_INT64: if (proto_.int64_val_size() == 1) { - ctx->SetOutput( - 0, b->Broadcast(b->ConstantR0(proto_.int64_val(0)), - shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(xla::ConstantR0( + b, proto_.int64_val(0)), + shape.dim_sizes())); return; } break; diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 627bad12f33c82e91bc3c6f3323f562bc8174056..48ac4867edcef97be001a24f42f6a35225d466c9 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -51,8 +53,8 @@ xla::XlaOp CreateExpandedZero(const TensorShape& filter_shape, DataType dtype, xla::XlaBuilder* builder) { TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - return builder->Broadcast(XlaHelpers::Zero(builder, dtype), - expanded_filter_shape.dim_sizes()); + return xla::Broadcast(XlaHelpers::Zero(builder, dtype), + expanded_filter_shape.dim_sizes()); } // Create a mask for depthwise convolution that will make a normal convolution @@ -95,32 +97,27 @@ xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape, // Create a M sized linspace and an M*N sized linspace that will be // broadcasted into perpendicular dimensions and compared. - xla::XlaOp input_feature_iota; - // DT_INT32 Iota will always return status::OK(). - TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature, - &input_feature_iota)); - xla::XlaOp expanded_feature_iota; - TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, - input_feature * depthwise_multiplier, - &expanded_feature_iota)); + xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature); + xla::XlaOp expanded_feature_iota = + xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier); // Divide the M*N sized linspace by the depthwise_multiplier to create // [0 0 1 1 2 2] in the example in the function comment. expanded_feature_iota = - builder->Div(expanded_feature_iota, - XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, - depthwise_multiplier)); + xla::Div(expanded_feature_iota, + XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, + depthwise_multiplier)); // Broadcast the N*M linspace to [H, W, ..., M, M*N]. auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes(); expanded_feature_broadcast_dims.pop_back(); - auto broadcasted_expanded_feature_iota = builder->Broadcast( - expanded_feature_iota, expanded_feature_broadcast_dims); + auto broadcasted_expanded_feature_iota = + xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims); // Compare the broadcasted linspace to the input feature linspace in the // input feature dimension to create a diagonal predicate. - return builder->Eq(broadcasted_expanded_feature_iota, input_feature_iota, - {expanded_filter_shape.dims() - 2}); + return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota, + {expanded_filter_shape.dims() - 2}); } // Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding @@ -142,16 +139,16 @@ xla::XlaOp ExpandFilterForDepthwiseConvolution(const TensorShape& filter_shape, implicit_broadcast_filter_shape.dims() - 1, depthwise_multiplier * input_feature); auto implicit_broadcast_filter = - builder->Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); + xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); // Broadcast the filter to [H, W, ..., M, M*N]. auto expanded_zero = CreateExpandedZero(filter_shape, dtype, builder); - auto expanded_filter = builder->Add(implicit_broadcast_filter, expanded_zero); + auto expanded_filter = xla::Add(implicit_broadcast_filter, expanded_zero); // If the filter mask is set, choose the broadcasted filter, othwerwise, // choose zero. - return builder->Select(CreateExpandedFilterMask(filter_shape, builder), - expanded_filter, expanded_zero); + return xla::Select(CreateExpandedFilterMask(filter_shape, builder), + expanded_filter, expanded_zero); } // Inverse of ExpandFilterForDepthwiseConvolution. @@ -162,17 +159,17 @@ xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx, xla::XlaBuilder* builder) { TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - auto masked_expanded_filter = builder->Select( + auto masked_expanded_filter = xla::Select( CreateExpandedFilterMask(filter_shape, builder), filter_backprop, CreateExpandedZero(filter_shape, dtype, builder)); - return builder->Reshape( + return xla::Reshape( // This reduce does not need inputs to be converted with // XlaHelpers::SumAccumulationType() since the ExpandedFilterMask with // ExpandedZero guarantees that only one element is non zero, so there // cannot be accumulated precision error. - builder->Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), - *ctx->GetOrCreateAdd(dtype), - {expanded_filter_shape.dims() - 2}), + xla::Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), + *ctx->GetOrCreateAdd(dtype), + {expanded_filter_shape.dims() - 2}), filter_shape.dim_sizes()); } @@ -289,8 +286,8 @@ class ConvOp : public XlaOpKernel { } xla::XlaOp conv = - b->ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, - lhs_dilation, rhs_dilation, dims); + xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, + lhs_dilation, rhs_dilation, dims); ctx->SetOutput(0, conv); } @@ -435,11 +432,11 @@ class ConvBackpropInputOp : public XlaOpKernel { } // Mirror the filter in the spatial dimensions. - xla::XlaOp mirrored_weights = b->Rev(filter, kernel_spatial_dims); + xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims); // activation gradients // = gradients (with padding and dilation) mirrored_weights - xla::XlaOp in_backprop = b->ConvGeneralDilated( + xla::XlaOp in_backprop = xla::ConvGeneralDilated( out_backprop, mirrored_weights, /*window_strides=*/ones, padding, lhs_dilation, rhs_dilation, dnums); @@ -638,8 +635,8 @@ class ConvBackpropFilterOp : public XlaOpKernel { // This is done by specifying the window dilation factors in the // convolution HLO below. auto filter_backprop = - b->ConvGeneralDilated(activations, gradients, window_strides, padding, - /*lhs_dilation=*/ones, rhs_dilation, dnums); + xla::ConvGeneralDilated(activations, gradients, window_strides, padding, + /*lhs_dilation=*/ones, rhs_dilation, dnums); if (depthwise_) { filter_backprop = ContractFilterForDepthwiseBackprop( diff --git a/tensorflow/compiler/tf2xla/kernels/cross_op.cc b/tensorflow/compiler/tf2xla/kernels/cross_op.cc index 7fcd4170fb79a574663c1abffe873d4b53f471d3..500a564f3f0489a42dbc9d5b70ae7708a7a43973 100644 --- a/tensorflow/compiler/tf2xla/kernels/cross_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cross_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { namespace { @@ -58,21 +59,21 @@ class CrossOp : public XlaOpKernel { auto in1 = ctx->Input(1); starts.back() = 0; limits.back() = 1; - auto u1 = b->Slice(in0, starts, limits, strides); - auto v1 = b->Slice(in1, starts, limits, strides); + auto u1 = xla::Slice(in0, starts, limits, strides); + auto v1 = xla::Slice(in1, starts, limits, strides); starts.back() = 1; limits.back() = 2; - auto u2 = b->Slice(in0, starts, limits, strides); - auto v2 = b->Slice(in1, starts, limits, strides); + auto u2 = xla::Slice(in0, starts, limits, strides); + auto v2 = xla::Slice(in1, starts, limits, strides); starts.back() = 2; limits.back() = 3; - auto u3 = b->Slice(in0, starts, limits, strides); - auto v3 = b->Slice(in1, starts, limits, strides); + auto u3 = xla::Slice(in0, starts, limits, strides); + auto v3 = xla::Slice(in1, starts, limits, strides); - auto s1 = b->Sub(b->Mul(u2, v3), b->Mul(u3, v2)); - auto s2 = b->Sub(b->Mul(u3, v1), b->Mul(u1, v3)); - auto s3 = b->Sub(b->Mul(u1, v2), b->Mul(u2, v1)); - auto output = b->ConcatInDim({s1, s2, s3}, in0_shape.dims() - 1); + auto s1 = xla::Sub(xla::Mul(u2, v3), xla::Mul(u3, v2)); + auto s2 = xla::Sub(xla::Mul(u3, v1), xla::Mul(u1, v3)); + auto s3 = xla::Sub(xla::Mul(u1, v2), xla::Mul(u2, v1)); + auto output = xla::ConcatInDim(b, {s1, s2, s3}, in0_shape.dims() - 1); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc index 01aa1a83e7967921f1583b3ef18ec57e452dcfea..9ff3e0222831cb4339943966810eeae451e47a2c 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -96,18 +96,16 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { // First reshape the inputs, which should be a metadata-only // operation since we are flattening the dimensions in order. - auto lhs_shaped = builder->Reshape(lhs, broadcast_helper.x_reshape()); - auto rhs_shaped = builder->Reshape(rhs, broadcast_helper.y_reshape()); + auto lhs_shaped = xla::Reshape(lhs, broadcast_helper.x_reshape()); + auto rhs_shaped = xla::Reshape(rhs, broadcast_helper.y_reshape()); // Next broadcast the necessary input dimensions. We rely on the // XLA optimizer to be smart about the fact that we are asking // it to broadcast size 1 on some of these dimensions, to avoid // adding complexity to this code. - auto lhs_broadcast = - builder->Broadcast(lhs_shaped, broadcast_helper.x_bcast()); + auto lhs_broadcast = xla::Broadcast(lhs_shaped, broadcast_helper.x_bcast()); int lhs_size = broadcast_helper.x_bcast().size(); - auto rhs_broadcast = - builder->Broadcast(rhs_shaped, broadcast_helper.y_bcast()); + auto rhs_broadcast = xla::Broadcast(rhs_shaped, broadcast_helper.y_bcast()); int rhs_size = broadcast_helper.y_bcast().size(); // Now reshape them to the correct output shape. After the @@ -122,15 +120,15 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { lhs_reorder.push_back(i); lhs_reorder.push_back(i + lhs_size); } - auto lhs_output = builder->Reshape(lhs_broadcast, lhs_reorder, - broadcast_helper.output_shape()); + auto lhs_output = + xla::Reshape(lhs_broadcast, lhs_reorder, broadcast_helper.output_shape()); std::vector rhs_reorder; for (int i = 0; i < rhs_size; ++i) { rhs_reorder.push_back(i); rhs_reorder.push_back(i + rhs_size); } - auto rhs_output = builder->Reshape(rhs_broadcast, rhs_reorder, - broadcast_helper.output_shape()); + auto rhs_output = + xla::Reshape(rhs_broadcast, rhs_reorder, broadcast_helper.output_shape()); return {lhs_output, rhs_output}; } diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index 23243f62462c6315e359d9621823b19fc98c6218..f3149200250935629a6e4bf67bff0c048135ce3e 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { @@ -50,7 +51,6 @@ class DepthToSpaceOp : public XlaOpKernel { const gtl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); - xla::XlaBuilder* b = ctx->builder(); xla::XlaOp input = ctx->Input(0); int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); @@ -130,7 +130,7 @@ class DepthToSpaceOp : public XlaOpKernel { ") is not divisible by square of the block size (", block_size_, ")")); - xla::XlaOp reshaped = b->Reshape(input, reshaped_shape); + xla::XlaOp reshaped = xla::Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -141,7 +141,7 @@ class DepthToSpaceOp : public XlaOpKernel { // input_shape[2], // block_size_, // depth / (block_size_ * block_size_)] - xla::XlaOp permuted_reshaped = b->Transpose(reshaped, transpose_order); + xla::XlaOp permuted_reshaped = xla::Transpose(reshaped, transpose_order); // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -151,7 +151,7 @@ class DepthToSpaceOp : public XlaOpKernel { // input_shape[2] * block_size_, // depth / (block_size_ * block_size_)] // - xla::XlaOp output = b->Reshape(permuted_reshaped, output_shape); + xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 931705ba837153e1175cd9a209876ef5ec93f0fc..378b62c0d613c61e6438d1daa8977daff75ea0c3 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" @@ -38,16 +40,14 @@ xla::StatusOr CreateDiagonal( // // This produces a predicate matrix of the right size, with "true" on the // diagonal. - xla::XlaOp iota; - TF_RETURN_IF_ERROR( - XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota)); - xla::XlaOp iota_broadcast = builder->Broadcast(iota, {last_dim_size}); - xla::XlaOp mask = builder->Eq(iota_broadcast, iota, {0}); + xla::XlaOp iota = xla::Iota(builder, xla::S32, last_dim_size); + xla::XlaOp iota_broadcast = xla::Broadcast(iota, {last_dim_size}); + xla::XlaOp mask = xla::Eq(iota_broadcast, iota, {0}); // If this is a batched diagonal, broadcast the mask across the other // dimensions. if (!other_dims.empty()) { - mask = builder->Broadcast(mask, other_dims); + mask = xla::Broadcast(mask, other_dims); } // Broadcast the input, and then use the mask computed above to select the @@ -64,7 +64,7 @@ xla::StatusOr CreateDiagonal( std::vector broadcast_dims(other_dims.begin(), other_dims.end()); broadcast_dims.push_back(1LL); broadcast_dims.push_back(last_dim_size); - xla::XlaOp input_broadcast = builder->Reshape(input, broadcast_dims); + xla::XlaOp input_broadcast = xla::Reshape(input, broadcast_dims); broadcast_dims[broadcast_dims.size() - 2] = last_dim_size; xla::PrimitiveType element_type; @@ -74,8 +74,8 @@ xla::StatusOr CreateDiagonal( xla::ShapeUtil::MakeShape(element_type, broadcast_dims); xla::XlaOp zeros = Zeros(builder, broadcast_shape); - input_broadcast = builder->Add(input_broadcast, zeros); - return builder->Select(mask, input_broadcast, zeros); + input_broadcast = xla::Add(input_broadcast, zeros); + return xla::Select(mask, input_broadcast, zeros); } class DiagOp : public XlaOpKernel { @@ -104,7 +104,7 @@ class DiagOp : public XlaOpKernel { // Flattens the input to 1D. int64 size = input_shape.num_elements(); - input = builder->Reshape(input, {size}); + input = xla::Reshape(input, {size}); // Create an R2 with the R1 diagonal. auto diag_or_status = @@ -116,7 +116,7 @@ class DiagOp : public XlaOpKernel { std::vector new_dims(dims.size() * 2); std::copy(dims.begin(), dims.end(), new_dims.begin()); std::copy(dims.begin(), dims.end(), new_dims.begin() + dims.size()); - diag = builder->Reshape(diag, new_dims); + diag = xla::Reshape(diag, new_dims); ctx->SetOutput(0, diag); } @@ -170,21 +170,21 @@ class DiagPartOp : public XlaOpKernel { // Flattens the input to 1D. int64 size = input_shape.num_elements(); - diag = builder->Reshape(diag, {size}); + diag = xla::Reshape(diag, {size}); // Adds padding after the last element of 'new_size'. xla::PaddingConfig config; auto* dim = config.add_dimensions(); dim->set_edge_padding_high(new_size); auto zero = XlaHelpers::Zero(builder, input_type(0)); - diag = builder->Pad(diag, zero, config); + diag = xla::Pad(diag, zero, config); // Reshapes so the diagonal is now in the first column. - diag = builder->Reshape(diag, {new_size, new_size + 1}); + diag = xla::Reshape(diag, {new_size, new_size + 1}); // Slices out the first column and reshapes to the final shape. - diag = builder->Slice(diag, {0, 0}, {new_size, 1}, {1, 1}); - diag = builder->Reshape(diag, new_dims); + diag = xla::Slice(diag, {0, 0}, {new_size, 1}, {1, 1}); + diag = xla::Reshape(diag, new_dims); ctx->SetOutput(0, diag); } @@ -265,7 +265,7 @@ class MatrixDiagPartOp : public XlaOpKernel { // Collapses the last two dimensions. std::vector flattened_dims(dims.begin(), dims.end() - 1); flattened_dims.back() *= dims.back(); - diag = builder->Reshape(diag, flattened_dims); + diag = xla::Reshape(diag, flattened_dims); // Slices or pads the last dimension to 'target_size'. int64 actual_size = flattened_dims.back(); @@ -276,13 +276,13 @@ class MatrixDiagPartOp : public XlaOpKernel { auto* dim = config.mutable_dimensions(flattened_dims.size() - 1); dim->set_edge_padding_high(target_size - actual_size); auto zero = XlaHelpers::Zero(builder, input_type(0)); - diag = builder->Pad(diag, zero, config); + diag = xla::Pad(diag, zero, config); } else if (actual_size > target_size) { std::vector start(flattened_dims.size(), 0); std::vector limits(flattened_dims.begin(), flattened_dims.end()); std::vector strides(flattened_dims.size(), 1); limits[flattened_dims.size() - 1] = target_size; - diag = builder->Slice(diag, start, limits, strides); + diag = xla::Slice(diag, start, limits, strides); } // Reshape so the target values are in the first position of the last @@ -290,18 +290,18 @@ class MatrixDiagPartOp : public XlaOpKernel { std::vector unflattened_dims(dims.begin(), dims.end()); dims[last_dim - 1] = smaller_dim_size; dims[last_dim] = last_dim_size + 1; - diag = builder->Reshape(diag, dims); + diag = xla::Reshape(diag, dims); // Slices out the first column and reshapes to the final shape. std::vector start(dims.size(), 0); std::vector limits(dims.begin(), dims.end()); std::vector strides(dims.size(), 1); limits[last_dim] = 1; - diag = builder->Slice(diag, start, limits, strides); + diag = xla::Slice(diag, start, limits, strides); // Collapses away the last dimension. dims.pop_back(); - diag = builder->Reshape(diag, dims); + diag = xla::Reshape(diag, dims); ctx->SetOutput(0, diag); } diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc index 0419de78b2ee83fd395e8bf23444fde84f30bba2..3b86ea34c9e7d943eb9c7de222e0a2be049ebc68 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -57,8 +57,8 @@ class DynamicUpdateSliceOp : public XlaOpKernel { input_shape.DebugString(), "; update shape is ", update_shape.DebugString())); - xla::XlaOp result = ctx->builder()->DynamicUpdateSlice( - ctx->Input(0), ctx->Input(1), ctx->Input(2)); + xla::XlaOp result = + xla::DynamicUpdateSlice(ctx->Input(0), ctx->Input(1), ctx->Input(2)); ctx->SetOutput(0, result); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index dd4a16908779508380b36f43ce2306ff2f5fb8c4..958231505b50431b9bb267b0a3cc5ed56e3aeb21 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -150,8 +151,7 @@ class DynamicStitchOp : public XlaOpKernel { if (new_shape == data_shapes[input_num]) { input[input_num] = handle; } else { - input[input_num] = - ctx->builder()->Reshape(handle, new_shape.dim_sizes()); + input[input_num] = xla::Reshape(handle, new_shape.dim_sizes()); } } @@ -175,10 +175,10 @@ class DynamicStitchOp : public XlaOpKernel { // And place it in the concat list in the place indicated by // the index. to_concat[index_num] = - ctx->builder()->Slice(expression, slice_start, slice_limit, stride); + xla::Slice(expression, slice_start, slice_limit, stride); } - ctx->SetOutput(0, ctx->builder()->ConcatInDim(to_concat, 0)); + ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), to_concat, 0)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc index 493781a1e68b8906f1a7e018e5710130e2eb08b5..2c76bcee2593b820eafe09af3a52736ed8a92f86 100644 --- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc @@ -34,9 +34,9 @@ class EluOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* b = ctx->builder(); const auto zero = XlaHelpers::Zero(b, input_type(0)); - const auto pred = b->Gt(ctx->Input(0), zero); - const auto expm1 = b->Expm1(ctx->Input(0)); - ctx->SetOutput(0, b->Select(pred, ctx->Input(0), expm1)); + const auto pred = xla::Gt(ctx->Input(0), zero); + const auto expm1 = xla::Expm1(ctx->Input(0)); + ctx->SetOutput(0, xla::Select(pred, ctx->Input(0), expm1)); } }; @@ -51,9 +51,9 @@ class EluGradOp : public XlaOpKernel { const auto one = XlaHelpers::One(b, input_type(0)); const auto grad = ctx->Input(0); const auto activation = ctx->Input(1); - const auto exp_grad = b->Mul(grad, b->Add(activation, one)); - const auto pred = b->Gt(activation, zero); - ctx->SetOutput(0, b->Select(pred, grad, exp_grad)); + const auto exp_grad = xla::Mul(grad, xla::Add(activation, one)); + const auto pred = xla::Gt(activation, zero); + ctx->SetOutput(0, xla::Select(pred, grad, exp_grad)); } }; @@ -71,10 +71,10 @@ class SeluOp : public XlaOpKernel { 1.0507009873554804934193349852946); const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0), 1.7580993408473768599402175208123); - const auto pred = b->Gt(ctx->Input(0), zero); - const auto expm1 = b->Expm1(ctx->Input(0)); - ctx->SetOutput(0, b->Select(pred, b->Mul(scale, ctx->Input(0)), - b->Mul(scale_alpha, expm1))); + const auto pred = xla::Gt(ctx->Input(0), zero); + const auto expm1 = xla::Expm1(ctx->Input(0)); + ctx->SetOutput(0, xla::Select(pred, xla::Mul(scale, ctx->Input(0)), + xla::Mul(scale_alpha, expm1))); } }; @@ -92,10 +92,10 @@ class SeluGradOp : public XlaOpKernel { 1.7580993408473768599402175208123); const auto grad = ctx->Input(0); const auto activation = ctx->Input(1); - const auto lin_grad = b->Mul(grad, scale); - const auto exp_grad = b->Mul(grad, b->Add(activation, scale_alpha)); - const auto pred = b->Gt(activation, zero); - ctx->SetOutput(0, b->Select(pred, lin_grad, exp_grad)); + const auto lin_grad = xla::Mul(grad, scale); + const auto exp_grad = xla::Mul(grad, xla::Add(activation, scale_alpha)); + const auto pred = xla::Gt(activation, zero); + ctx->SetOutput(0, xla::Select(pred, lin_grad, exp_grad)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index 6df01cabbf1d98c0299bfd808bcc6db6223c4777..65d42a302fca48c7b5f88813f80e975823f63ddf 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -17,6 +17,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { @@ -110,13 +112,11 @@ class ExtractImagePatchesOp : public XlaOpKernel { // Builds an identity matrix as a broadcast equality of iotas. // iota = np.arange(np.prod(ksize), depth) // filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32) - xla::XlaOp iota; - TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, - kernel_size * depth, &iota)); + xla::XlaOp iota = xla::Iota(builder, xla::S32, kernel_size * depth); - auto lhs = builder->Reshape(iota, lhs_shape); - auto filter = builder->ConvertElementType( - builder->Eq(lhs, iota, {num_spatial_dims + 1}), type); + auto lhs = xla::Reshape(iota, lhs_shape); + auto filter = xla::ConvertElementType( + xla::Eq(lhs, iota, {num_spatial_dims + 1}), type); xla::ConvolutionDimensionNumbers dims; std::vector window_strides(num_spatial_dims); @@ -148,8 +148,8 @@ class ExtractImagePatchesOp : public XlaOpKernel { } xla::XlaOp conv = - builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides, - padding, lhs_dilation, rhs_dilation, dims); + xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, + lhs_dilation, rhs_dilation, dims); ctx->SetOutput(0, conv); } diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index 8f0de0a524c908b598c1a2165a462275346ad137..2fd1a34741e1c7235397f9a69dd8444b4679fa22 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -49,20 +50,20 @@ void XlaNudge(xla::XlaBuilder* b, const DataType data_type, const float quant_min_value, const float quant_max_value, xla::XlaOp* nudged_min, xla::XlaOp* nudged_max, xla::XlaOp* scale) { - *scale = b->Div(b->Sub(max, min), - XlaHelpers::FloatLiteral(b, data_type, - quant_max_value - quant_min_value)); + *scale = xla::Div(xla::Sub(max, min), + XlaHelpers::FloatLiteral( + b, data_type, quant_max_value - quant_min_value)); xla::XlaOp quant_min = XlaHelpers::FloatLiteral(b, data_type, quant_min_value); - xla::XlaOp zero_point_from_min = b->Sub(quant_min, b->Div(min, *scale)); + xla::XlaOp zero_point_from_min = xla::Sub(quant_min, xla::Div(min, *scale)); xla::XlaOp quant_max = XlaHelpers::FloatLiteral(b, data_type, quant_max_value); xla::XlaOp nudged_zero_point = - b->Select(b->Le(zero_point_from_min, quant_min), quant_min, - b->Select(b->Ge(zero_point_from_min, quant_max), quant_max, - b->Round(zero_point_from_min))); - *nudged_min = b->Mul(b->Sub(quant_min, nudged_zero_point), *scale); - *nudged_max = b->Mul(b->Sub(quant_max, nudged_zero_point), *scale); + xla::Select(xla::Le(zero_point_from_min, quant_min), quant_min, + xla::Select(xla::Ge(zero_point_from_min, quant_max), + quant_max, xla::Round(zero_point_from_min))); + *nudged_min = xla::Mul(xla::Sub(quant_min, nudged_zero_point), *scale); + *nudged_max = xla::Mul(xla::Sub(quant_max, nudged_zero_point), *scale); } xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input, @@ -71,14 +72,14 @@ xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input, const xla::XlaOp& nudged_input_max, const xla::XlaOp& input_scale) { xla::XlaOp one = XlaHelpers::FloatLiteral(b, data_type, 1.0f); - xla::XlaOp inv_scale = b->Div(one, input_scale); + xla::XlaOp inv_scale = xla::Div(one, input_scale); xla::XlaOp half = XlaHelpers::FloatLiteral(b, data_type, 0.5f); - xla::XlaOp clamped = b->Clamp(nudged_input_min, input, nudged_input_max); - xla::XlaOp clamped_shifted = b->Sub(clamped, nudged_input_min); + xla::XlaOp clamped = xla::Clamp(nudged_input_min, input, nudged_input_max); + xla::XlaOp clamped_shifted = xla::Sub(clamped, nudged_input_min); xla::XlaOp rounded = - b->Floor(b->Add(b->Mul(clamped_shifted, inv_scale), half)); - return b->Add(b->Mul(rounded, input_scale), nudged_input_min); + xla::Floor(xla::Add(xla::Mul(clamped_shifted, inv_scale), half)); + return xla::Add(xla::Mul(rounded, input_scale), nudged_input_min); } class FakeQuantWithMinMaxArgsOp : public XlaOpKernel { @@ -163,11 +164,11 @@ class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel { xla::XlaOp nudged_input_max = XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_); - xla::XlaOp between_nudged_min_max = - b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max)); - xla::XlaOp zeroes = b->Broadcast(XlaHelpers::Zero(b, data_type), - gradient_shape.dim_sizes()); - xla::XlaOp output = b->Select(between_nudged_min_max, gradient, zeroes); + xla::XlaOp between_nudged_min_max = xla::And( + xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max)); + xla::XlaOp zeroes = xla::Broadcast(XlaHelpers::Zero(b, data_type), + gradient_shape.dim_sizes()); + xla::XlaOp output = xla::Select(between_nudged_min_max, gradient, zeroes); ctx->SetOutput(0, output); } @@ -249,25 +250,25 @@ class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel { XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_, &nudged_input_min, &nudged_input_max, &input_scale); - xla::XlaOp between_nudged_min_max = - b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max)); + xla::XlaOp between_nudged_min_max = xla::And( + xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max)); xla::XlaOp zero = XlaHelpers::Zero(b, data_type); - xla::XlaOp zeroes = b->Broadcast(zero, gradient_shape.dim_sizes()); - xla::XlaOp output0 = b->Select(between_nudged_min_max, gradient, zeroes); + xla::XlaOp zeroes = xla::Broadcast(zero, gradient_shape.dim_sizes()); + xla::XlaOp output0 = xla::Select(between_nudged_min_max, gradient, zeroes); ctx->SetOutput(0, output0); - xla::XlaOp below_min = b->Lt(input, nudged_input_min); - xla::XlaOp select1 = b->Select(below_min, gradient, zeroes); - xla::XlaOp reduce1 = b->ReduceAll( + xla::XlaOp below_min = xla::Lt(input, nudged_input_min); + xla::XlaOp select1 = xla::Select(below_min, gradient, zeroes); + xla::XlaOp reduce1 = xla::ReduceAll( XlaHelpers::ConvertElementType(b, select1, accumulation_type), XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type)); xla::XlaOp output1 = XlaHelpers::ConvertElementType(b, reduce1, data_type); ctx->SetOutput(1, output1); - xla::XlaOp above_max = b->Gt(input, nudged_input_max); - xla::XlaOp select2 = b->Select(above_max, gradient, zeroes); - xla::XlaOp reduce2 = b->ReduceAll( + xla::XlaOp above_max = xla::Gt(input, nudged_input_max); + xla::XlaOp select2 = xla::Select(above_max, gradient, zeroes); + xla::XlaOp reduce2 = xla::ReduceAll( XlaHelpers::ConvertElementType(b, select2, accumulation_type), XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type)); diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index 933924cad1c7cac2879bd4720cb21ffc33c23f50..b2b00e51e3b00fa93c258af489cf0f4a3e6e764b 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -62,8 +63,7 @@ class GenericFftOp : public XlaOpKernel { } } - xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp fft = b->Fft(ctx->Input(0), fft_type_, fft_length); + xla::XlaOp fft = xla::Fft(ctx->Input(0), fft_type_, fft_length); ctx->SetOutput(0, fft); } diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index e4467a0fb138ed7919af62ed032c0f5abee3e4f6..95faa1d058f4c0d3fa802b157c6daba1e1adaf41 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" @@ -59,11 +60,11 @@ class FillOp : public XlaOpKernel { xla::XlaOp data = ctx->Input(1); if (value_shape.dims() > 0) { CHECK_EQ(value_shape.dims(), 1); - data = ctx->builder()->Reshape(data, {}); + data = xla::Reshape(data, {}); } // Emit the actual computation, which broadcasts the scalar to the // desired shape. - auto result = ctx->builder()->Broadcast(data, broadcast); + auto result = xla::Broadcast(data, broadcast); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index d13e25bcddae16d0cd630403219657121b80868d..5f041be5df226ed996b21844c0cf92b6dfac005c 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -75,8 +76,8 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, out_shape.AppendShape(indices_shape_no_index_vectors); out_shape.AppendShape(input_shape_post_axis); - *gather_output = builder->Broadcast(XlaHelpers::Zero(builder, dtype), - out_shape.dim_sizes()); + *gather_output = + xla::Broadcast(XlaHelpers::Zero(builder, dtype), out_shape.dim_sizes()); return Status::OK(); } @@ -142,7 +143,7 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, dim_numbers.add_gather_dims_to_operand_dims(i); } - *gather_output = builder->Gather(input, indices, dim_numbers, window_bounds); + *gather_output = xla::Gather(input, indices, dim_numbers, window_bounds); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index d48c6eea754f75a8879d3938f233a6a591d26d0d..f5fcf3cacdbff8297bc42fcb0cf79c2bc83a4e11 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { @@ -199,13 +200,13 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { } } - xla::XlaOp outputs = - b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation, - b->Tuple(inputs), *else_result.computation); + xla::XlaOp outputs = xla::Conditional( + ctx->Input(0), xla::Tuple(b, inputs), *then_result.computation, + xla::Tuple(b, inputs), *else_result.computation); // Sets non-variable outputs. for (int i = 0; i < output_types_.size(); ++i) { if (ctx->input_type(i) != DT_RESOURCE) { - xla::XlaOp output_handle = b->GetTupleElement(outputs, i); + xla::XlaOp output_handle = xla::GetTupleElement(outputs, i); if (VLOG_IS_ON(2)) { LOG(INFO) << "Setting output " << i; auto shape_or = b->GetShape(output_handle); @@ -233,7 +234,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, resource->SetFromPack( arguments[update.input_index].tensor_array_gradients, - b->GetTupleElement(outputs, pos), b)); + xla::GetTupleElement(outputs, pos), b)); } VLOG(2) << "If variable: pos: " << update.input_index << " name: " << resource->name() diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index 1568b33679963c1a6630525f60560180d40b8d53..cb4caf7bcb4caaa1bf7e0e79e52bb966a8838db3 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { namespace { @@ -32,23 +33,26 @@ std::array RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b, auto red = rgb[0]; auto green = rgb[1]; auto blue = rgb[2]; - auto value = b->Max(b->Max(red, green), blue); - auto minimum = b->Min(b->Min(red, green), blue); - auto range = b->Sub(value, minimum); - - auto zeros = b->Broadcast(zero, shape.dim_sizes()); - auto saturation = b->Select(b->Gt(value, zero), b->Div(range, value), zeros); - - auto norm = b->Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range); - - auto hue = b->Select(b->Eq(green, value), - b->Add(b->Mul(norm, b->Sub(blue, red)), - XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)), - b->Add(b->Mul(norm, b->Sub(red, green)), - XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0))); - hue = b->Select(b->Eq(red, value), b->Mul(norm, b->Sub(green, blue)), hue); - hue = b->Select(b->Gt(range, zero), hue, zeros); - hue = b->Select(b->Lt(hue, zero), b->Add(hue, one), hue); + auto value = xla::Max(xla::Max(red, green), blue); + auto minimum = xla::Min(xla::Min(red, green), blue); + auto range = xla::Sub(value, minimum); + + auto zeros = xla::Broadcast(zero, shape.dim_sizes()); + auto saturation = + xla::Select(xla::Gt(value, zero), xla::Div(range, value), zeros); + + auto norm = xla::Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range); + + auto hue = + xla::Select(xla::Eq(green, value), + xla::Add(xla::Mul(norm, xla::Sub(blue, red)), + XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)), + xla::Add(xla::Mul(norm, xla::Sub(red, green)), + XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0))); + hue = xla::Select(xla::Eq(red, value), xla::Mul(norm, xla::Sub(green, blue)), + hue); + hue = xla::Select(xla::Gt(range, zero), hue, zeros); + hue = xla::Select(xla::Lt(hue, zero), xla::Add(hue, one), hue); return {hue, saturation, value}; } @@ -66,15 +70,15 @@ std::array HSVToRGB(xla::XlaBuilder* b, auto four = XlaHelpers::FloatLiteral(b, dtype, 4.0); auto six = XlaHelpers::FloatLiteral(b, dtype, 6.0); - auto dh = b->Mul(hue, six); - auto dr = b->Clamp(zero, b->Sub(b->Abs(b->Sub(dh, three)), one), one); - auto dg = b->Clamp(zero, b->Sub(two, b->Abs(b->Sub(dh, two))), one); - auto db = b->Clamp(zero, b->Sub(two, b->Abs(b->Sub(dh, four))), one); - auto one_minus_s = b->Sub(one, saturation); + auto dh = xla::Mul(hue, six); + auto dr = xla::Clamp(zero, xla::Sub(xla::Abs(xla::Sub(dh, three)), one), one); + auto dg = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, two))), one); + auto db = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, four))), one); + auto one_minus_s = xla::Sub(one, saturation); - auto red = b->Mul(b->Add(one_minus_s, b->Mul(saturation, dr)), value); - auto green = b->Mul(b->Add(one_minus_s, b->Mul(saturation, dg)), value); - auto blue = b->Mul(b->Add(one_minus_s, b->Mul(saturation, db)), value); + auto red = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dr)), value); + auto green = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dg)), value); + auto blue = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, db)), value); return {red, green, blue}; } @@ -97,21 +101,21 @@ class RGBToHSVOp : public XlaOpKernel { xla::XlaBuilder* b = context->builder(); xla::XlaOp input = context->Input(0); - xla::XlaOp red = - b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp green = - b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp blue = - b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, - /*dimno=*/channel_dim); + xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0, + /*limit_index=*/1, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1, + /*limit_index=*/2, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2, + /*limit_index=*/3, /*stride=*/1, + /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; channel_shape.set_dim(channel_dim, 1); auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), channel_shape); - context->SetOutput(0, b->ConcatInDim(hsv, channel_dim)); + context->SetOutput(0, xla::ConcatInDim(b, hsv, channel_dim)); } }; REGISTER_XLA_OP(Name("RGBToHSV"), RGBToHSVOp); @@ -134,20 +138,20 @@ class HSVToRGBOp : public XlaOpKernel { xla::XlaBuilder* b = context->builder(); xla::XlaOp input = context->Input(0); - xla::XlaOp hue = - b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp saturation = - b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp value = - b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, - /*dimno=*/channel_dim); + xla::XlaOp hue = xla::SliceInDim(input, /*start_index=*/0, + /*limit_index=*/1, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp saturation = xla::SliceInDim(input, /*start_index=*/1, + /*limit_index=*/2, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp value = xla::SliceInDim(input, /*start_index=*/2, + /*limit_index=*/3, /*stride=*/1, + /*dimno=*/channel_dim); auto rgb = HSVToRGB(context->builder(), {hue, saturation, value}, context->input_type(0)); - context->SetOutput(0, b->ConcatInDim(rgb, channel_dim)); + context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim)); } }; REGISTER_XLA_OP(Name("HSVToRGB"), HSVToRGBOp); @@ -182,18 +186,20 @@ class AdjustContrastOpV2 : public XlaOpKernel { const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); auto converted = XlaHelpers::ConvertElementType(b, input, accumulation_type); - auto reduce = b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *context->GetOrCreateAdd(accumulation_type), - {height_dim, width_dim}); + auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *context->GetOrCreateAdd(accumulation_type), + {height_dim, width_dim}); auto output = XlaHelpers::ConvertElementType(b, reduce, type); - output = b->Div(output, XlaHelpers::FloatLiteral(b, type, height * width)); + output = + xla::Div(output, XlaHelpers::FloatLiteral(b, type, height * width)); std::vector broadcast_dims(input_shape.dims() - 2); std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); broadcast_dims.back() = channel_dim; - output = b->Add(b->Mul(input, factor), - b->Mul(output, b->Sub(XlaHelpers::One(b, type), factor)), - broadcast_dims); + output = + xla::Add(xla::Mul(input, factor), + xla::Mul(output, xla::Sub(XlaHelpers::One(b, type), factor)), + broadcast_dims); context->SetOutput(0, output); } }; @@ -226,26 +232,26 @@ class AdjustSaturationOp : public XlaOpKernel { DataType type = context->input_type(0); - xla::XlaOp red = - b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp green = - b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp blue = - b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, - /*dimno=*/channel_dim); + xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0, + /*limit_index=*/1, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1, + /*limit_index=*/2, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2, + /*limit_index=*/3, /*stride=*/1, + /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; channel_shape.set_dim(channel_dim, 1); auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), channel_shape); - hsv[1] = b->Clamp(XlaHelpers::Zero(b, type), b->Mul(hsv[1], scale), - XlaHelpers::One(b, type)); + hsv[1] = xla::Clamp(XlaHelpers::Zero(b, type), xla::Mul(hsv[1], scale), + XlaHelpers::One(b, type)); auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0)); - context->SetOutput(0, b->ConcatInDim(rgb, channel_dim)); + context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim)); } }; REGISTER_XLA_OP(Name("AdjustSaturation"), AdjustSaturationOp); @@ -276,15 +282,15 @@ class AdjustHueOp : public XlaOpKernel { DataType type = context->input_type(0); - xla::XlaOp red = - b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp green = - b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp blue = - b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, - /*dimno=*/channel_dim); + xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0, + /*limit_index=*/1, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1, + /*limit_index=*/2, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2, + /*limit_index=*/3, /*stride=*/1, + /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; channel_shape.set_dim(channel_dim, 1); auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), @@ -294,12 +300,13 @@ class AdjustHueOp : public XlaOpKernel { auto one = XlaHelpers::One(b, type); auto& hue = hsv[0]; - hue = b->Rem(b->Add(hsv[0], delta), one); - hue = b->Select(b->Lt(hue, zero), b->Rem(b->Add(one, hue), one), hue); + hue = xla::Rem(xla::Add(hsv[0], delta), one); + hue = + xla::Select(xla::Lt(hue, zero), xla::Rem(xla::Add(one, hue), one), hue); auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0)); - context->SetOutput(0, b->ConcatInDim(rgb, channel_dim)); + context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim)); } }; REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp); diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 79d3a6979cec4c6bda92a71dcff4ddd2151367d5..d6bf92fb3df8d38909df99e11c85ede4fac2bf81 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/lib/math/math_util.h" @@ -127,48 +129,41 @@ const int64 kMax2DKernelSize = 16; xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, gtl::ArraySlice kernel_size, int64 channels) { - xla::XlaOp channels_iota; - // DT_INT32 Iota will always return status::OK(). - TF_CHECK_OK( - XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); - - auto diag = builder->ConvertElementType( - builder->Eq( - builder->Broadcast(channels_iota, {2 * kernel_size[0] - 1, + xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels); + + auto diag = xla::ConvertElementType( + xla::Eq(xla::Broadcast(channels_iota, {2 * kernel_size[0] - 1, 2 * kernel_size[1] - 1, channels}), - channels_iota, /*broadcast_dimensions=*/{2}), + channels_iota, /*broadcast_dimensions=*/{2}), xla::PrimitiveType::F32); - return builder->Mul( - builder->Mul(diag, - builder->ConstantR1(Make1DKernel(kernel_size[1])), - /*broadcast_dimensions=*/{1}), - builder->ConstantR1(Make1DKernel(kernel_size[0])), + return xla::Mul( + xla::Mul(diag, + xla::ConstantR1(builder, Make1DKernel(kernel_size[1])), + /*broadcast_dimensions=*/{1}), + xla::ConstantR1(builder, Make1DKernel(kernel_size[0])), /*broadcast_dimensions=*/{0}); } xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder, gtl::ArraySlice kernel_size, int64 channels, int64 dim) { - xla::XlaOp channels_iota; - // DT_INT32 Iota will always return status::OK(). - TF_CHECK_OK( - XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); - - auto diag = builder->ConvertElementType( - builder->Eq(builder->Broadcast( - channels_iota, - {dim == 0 ? (2 * kernel_size[0] - 1) : 1, - dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}), - channels_iota, /*broadcast_dimensions=*/{2}), + xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels); + + auto diag = xla::ConvertElementType( + xla::Eq( + xla::Broadcast(channels_iota, + {dim == 0 ? (2 * kernel_size[0] - 1) : 1, + dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}), + channels_iota, /*broadcast_dimensions=*/{2}), xla::PrimitiveType::F32); if (dim == 1) { - return builder->Mul( - diag, builder->ConstantR1(Make1DKernel(kernel_size[1])), + return xla::Mul( + diag, xla::ConstantR1(builder, Make1DKernel(kernel_size[1])), /*broadcast_dimensions=*/{1}); } - return builder->Mul(diag, - builder->ConstantR1(Make1DKernel(kernel_size[0])), - /*broadcast_dimensions=*/{0}); + return xla::Mul(diag, + xla::ConstantR1(builder, Make1DKernel(kernel_size[0])), + /*broadcast_dimensions=*/{0}); } xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, @@ -208,7 +203,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { xla::XlaOp kernel = MakeBilinearResizeKernel(builder, dims.kernel_size, channels); - output = builder->ConvGeneralDilated( + output = xla::ConvGeneralDilated( input, kernel, dims.stride, /*padding=*/ {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, @@ -218,7 +213,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, } else { xla::XlaOp kernel0 = MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); - output = builder->ConvGeneralDilated( + output = xla::ConvGeneralDilated( input, kernel0, {dims.stride[0], 1}, /*padding=*/ {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, @@ -226,7 +221,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, /*rhs_dilation=*/{1, 1}, dimension_numbers); xla::XlaOp kernel1 = MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); - output = builder->ConvGeneralDilated( + output = xla::ConvGeneralDilated( output, kernel1, {1, dims.stride[1]}, /*padding=*/ {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, @@ -238,8 +233,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, // size > 1 dimension. for (int i = 0; i < num_spatial_dims; ++i) { if (in_size[i] == 1 && out_size[i] > 1) { - output = builder->Add(output, builder->ConstantR1(out_size[i], 0), - /*broadcast_dimensions=*/{1 + i}); + output = xla::Add(output, xla::ConstantR1(builder, out_size[i], 0), + /*broadcast_dimensions=*/{1 + i}); } } return output; @@ -279,12 +274,12 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, for (int i = 0; i < num_spatial_dims; ++i) { if (in_size[i] == 1 && grad_size[i] > 1) { kernel = - builder->Add(kernel, builder->ConstantR1(grad_size[i], 0), - /*broadcast_dimensions=*/{i}); + xla::Add(kernel, xla::ConstantR1(builder, grad_size[i], 0), + /*broadcast_dimensions=*/{i}); } } - output = builder->ConvGeneralDilated( + output = xla::ConvGeneralDilated( grad, kernel, /*window_strides=*/dims.kernel_size, /*padding=*/ {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, @@ -302,23 +297,23 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, // gradient contributions in that dimension. if (in_size[0] == 1 && grad_size[0] > 1) { kernel0 = - builder->Add(kernel0, builder->ConstantR1(grad_size[0], 0), - /*broadcast_dimensions=*/{0}); + xla::Add(kernel0, xla::ConstantR1(builder, grad_size[0], 0), + /*broadcast_dimensions=*/{0}); } if (in_size[1] == 1 && grad_size[1] > 1) { kernel1 = - builder->Add(kernel0, builder->ConstantR1(grad_size[1], 0), - /*broadcast_dimensions=*/{1}); + xla::Add(kernel0, xla::ConstantR1(builder, grad_size[1], 0), + /*broadcast_dimensions=*/{1}); } - output = builder->ConvGeneralDilated( + output = xla::ConvGeneralDilated( grad, kernel0, /*window_strides=*/{dims.kernel_size[0], 1}, /*padding=*/ {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, /*lhs_dilation=*/{dims.stride[0], 1}, /*rhs_dilation=*/{1, 1}, dimension_numbers); - output = builder->ConvGeneralDilated( + output = xla::ConvGeneralDilated( output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]}, /*padding=*/ {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, @@ -337,7 +332,7 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, } } if (pad_output) { - output = builder->Pad(output, builder->ConstantR0(0.0f), padding); + output = xla::Pad(output, xla::ConstantR0(builder, 0.0f), padding); } return output; } @@ -393,13 +388,13 @@ class ResizeBilinearOp : public XlaOpKernel { } } if (slice_input) { - input = b->Slice(input, {0, 0, 0, 0}, - {batch, slice_size[0], slice_size[1], channels}, - {1, 1, 1, 1}); + input = xla::Slice(input, {0, 0, 0, 0}, + {batch, slice_size[0], slice_size[1], channels}, + {1, 1, 1, 1}); } // Output is always type float. - input = b->ConvertElementType(input, xla::F32); + input = xla::ConvertElementType(input, xla::F32); // Special Case: // Instead of doing a ResizeUsingDilationAndConvolution directly, @@ -529,7 +524,7 @@ class ResizeBilinearGradOp : public XlaOpKernel { } } - output = b->ConvertElementType(output, output_type_); + output = xla::ConvertElementType(output, output_type_); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index 2c2d88486fda99d2380382a3e2f633f5bdc7478c..a020ebc729e4c07d1b182cc0585ba0f2bca46403 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -76,14 +77,15 @@ class ArgMaxCustomCallOp : public XlaOpKernel { // XLA passes to the function, so it is not included here. std::vector args; args.push_back(ctx->Input(0)); - args.push_back(b.ConstantLiteral( - *xla::Literal::CreateR1(input_shape.dim_sizes()))); + args.push_back(xla::ConstantLiteral( + &b, *xla::Literal::CreateR1(input_shape.dim_sizes()))); if (input_shape.dims() > 1) { // Don't bother passing the output shape and dim for the 1d case, since // the shape is always a scalar and the dim is always 0. - args.push_back(b.ConstantLiteral( - *xla::Literal::CreateR1(output_shape.dim_sizes()))); - args.push_back(b.ConstantLiteral(*xla::Literal::CreateR0(dim))); + args.push_back(xla::ConstantLiteral( + &b, *xla::Literal::CreateR1(output_shape.dim_sizes()))); + args.push_back( + xla::ConstantLiteral(&b, *xla::Literal::CreateR0(dim))); } xla::Shape xla_shape = @@ -94,10 +96,12 @@ class ArgMaxCustomCallOp : public XlaOpKernel { xla::XlaOp output; switch (input_shape.dims()) { case 1: - output = b.CustomCall("argmax_float_1d_xla_impl", args, xla_shape); + output = + xla::CustomCall(&b, "argmax_float_1d_xla_impl", args, xla_shape); break; case 2: - output = b.CustomCall("argmax_float_2d_xla_impl", args, xla_shape); + output = + xla::CustomCall(&b, "argmax_float_2d_xla_impl", args, xla_shape); break; default: OP_REQUIRES(ctx, false, diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc index 1decf7d72d72bb697477e7f841ced2a1a0d5fbe9..9e64711051d31107db1bf6f1966f9ed6f5630c34 100644 --- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -39,12 +39,12 @@ class L2LossOp : public XlaOpKernel { const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); auto t = XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type); - auto square = b->Mul(t, t); - auto reduce = b->Reduce(square, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), dims); + auto square = xla::Mul(t, t); + auto reduce = xla::Reduce(square, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), dims); auto deconverted = XlaHelpers::ConvertElementType(b, reduce, dtype); auto two = XlaHelpers::IntegerLiteral(b, dtype, 2); - ctx->SetOutput(0, b->Div(deconverted, two)); + ctx->SetOutput(0, xla::Div(deconverted, two)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc index 0388b4c830702ea00ec69fc42c6468326c88cf38..2fb072f827906d40dcf410f0312394c4f568a28d 100644 --- a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/lib/core/errors.h" @@ -90,8 +91,10 @@ class ListDiffOp : public XlaOpKernel { idx_output.push_back(i); } - context->SetOutput(0, context->builder()->ConstantR1(val_output)); - context->SetOutput(1, context->builder()->ConstantR1(idx_output)); + context->SetOutput(0, + xla::ConstantR1(context->builder(), val_output)); + context->SetOutput(1, + xla::ConstantR1(context->builder(), idx_output)); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc index 39fbf98a6274918840e9e351470f04c2d80c5d01..dc934543cb2f94fbe1e8f1f865156eb082d6a127 100644 --- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -50,8 +51,8 @@ class LRNOp : public XlaOpKernel { auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); auto converted = XlaHelpers::ConvertElementType(builder, input, accumulation_type); - auto squared = builder->Mul(converted, converted); - auto reduce = builder->ReduceWindow( + auto squared = xla::Mul(converted, converted); + auto reduce = xla::ReduceWindow( squared, XlaHelpers::Zero(builder, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, @@ -59,12 +60,12 @@ class LRNOp : public XlaOpKernel { auto sqr_sum = XlaHelpers::ConvertElementType(builder, reduce, input_type(0)); - auto scale = builder->Pow( - builder->Add(builder->ConstantR0(bias_), - builder->Mul(builder->ConstantR0(alpha_), sqr_sum)), - builder->ConstantR0(-beta_)); + auto scale = xla::Pow( + xla::Add(xla::ConstantR0(builder, bias_), + xla::Mul(xla::ConstantR0(builder, alpha_), sqr_sum)), + xla::ConstantR0(builder, -beta_)); - ctx->SetOutput(0, builder->Mul(input, scale)); + ctx->SetOutput(0, xla::Mul(input, scale)); } private: @@ -138,8 +139,8 @@ class LRNGradOp : public XlaOpKernel { auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); auto converted = XlaHelpers::ConvertElementType(builder, in_image, accumulation_type); - auto squared = builder->Mul(converted, converted); - auto reduce = builder->ReduceWindow( + auto squared = xla::Mul(converted, converted); + auto reduce = xla::ReduceWindow( squared, XlaHelpers::Zero(builder, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, @@ -148,17 +149,17 @@ class LRNGradOp : public XlaOpKernel { XlaHelpers::ConvertElementType(builder, reduce, input_type(0)); auto norm = - builder->Add(builder->ConstantR0(bias_), - builder->Mul(builder->ConstantR0(alpha_), sqr_sum)); + xla::Add(xla::ConstantR0(builder, bias_), + xla::Mul(xla::ConstantR0(builder, alpha_), sqr_sum)); - auto dy = builder->Mul( - builder->Mul(builder->ConstantR0(-2.0f * alpha_ * beta_), - builder->Div(out_image, norm)), + auto dy = xla::Mul( + xla::Mul(xla::ConstantR0(builder, -2.0f * alpha_ * beta_), + xla::Div(out_image, norm)), in_grads); auto converted_dy = XlaHelpers::ConvertElementType(builder, dy, accumulation_type); - auto dy_reduce = builder->ReduceWindow( + auto dy_reduce = xla::ReduceWindow( converted_dy, XlaHelpers::Zero(builder, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, @@ -166,10 +167,10 @@ class LRNGradOp : public XlaOpKernel { auto dy_reduced = XlaHelpers::ConvertElementType(builder, dy_reduce, input_type(0)); - xla::XlaOp gradients = builder->Add( - builder->Mul(in_image, dy_reduced), - builder->Mul(in_grads, - builder->Pow(norm, builder->ConstantR0(-beta_)))); + xla::XlaOp gradients = xla::Add( + xla::Mul(in_image, dy_reduced), + xla::Mul(in_grads, + xla::Pow(norm, xla::ConstantR0(builder, -beta_)))); ctx->SetOutput(0, gradients); } diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index 6949b296f4b9afe4a0c9152c763a9ad233b9f595..844080b8cf5462da201ce7671e4f9d02fa52c861 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { @@ -70,15 +71,15 @@ class MatMulOp : public XlaOpKernel { xla::XlaOp b = ctx->Input(1); if (is_sparse_) { if (a_type_ == DT_BFLOAT16) { - a = ctx->builder()->ConvertElementType(a, xla::F32); + a = xla::ConvertElementType(a, xla::F32); } if (b_type_ == DT_BFLOAT16) { - b = ctx->builder()->ConvertElementType(b, xla::F32); + b = xla::ConvertElementType(b, xla::F32); } } - auto lhs = (transpose_a_) ? ctx->builder()->Transpose(a, {1, 0}) : a; - auto rhs = (transpose_b_) ? ctx->builder()->Transpose(b, {1, 0}) : b; - ctx->SetOutput(0, ctx->builder()->Dot(lhs, rhs)); + auto lhs = (transpose_a_) ? xla::Transpose(a, {1, 0}) : a; + auto rhs = (transpose_b_) ? xla::Transpose(b, {1, 0}) : b; + ctx->SetOutput(0, xla::Dot(lhs, rhs)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc index fbd5dc0fdad4483aadbe9bc263cc1f7a034cee09..e06c87db7adb1840606208fe15cd68a3ca4d137a 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -50,6 +52,7 @@ class MatrixBandPartOp : public XlaOpKernel { xla::XlaOp num_upper = context->Input(2); DataType input_type = context->input_type(0); DataType index_type = context->input_type(1); + xla::PrimitiveType index_xla_type = context->input_xla_type(1); TensorShape batch_shape = input_shape; batch_shape.RemoveLastDims(2); @@ -58,33 +61,29 @@ class MatrixBandPartOp : public XlaOpKernel { // Compute 'offset', which is how many diagonals we are above/below the // diagonal. - xla::XlaOp iota_m; - OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, m, &iota_m)); + xla::XlaOp iota_m = xla::Iota(builder, index_xla_type, m); + xla::XlaOp iota_n = xla::Iota(builder, index_xla_type, n); - xla::XlaOp iota_n; - OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n)); - - auto offset = builder->Sub(builder->Broadcast(iota_n, {m}), iota_m, - /*broadcast_dimensions=*/{0}); + auto offset = xla::Sub(xla::Broadcast(iota_n, {m}), iota_m, + /*broadcast_dimensions=*/{0}); // If num_lower or num_upper are negative, include all lower/upper // diagonals. auto zero_index = XlaHelpers::Zero(builder, index_type); - num_lower = builder->Select( - builder->Lt(num_lower, zero_index), - XlaHelpers::IntegerLiteral(builder, index_type, m), num_lower); - num_upper = builder->Select( - builder->Lt(num_upper, zero_index), - XlaHelpers::IntegerLiteral(builder, index_type, n), num_upper); + num_lower = xla::Select(xla::Lt(num_lower, zero_index), + XlaHelpers::IntegerLiteral(builder, index_type, m), + num_lower); + num_upper = xla::Select(xla::Lt(num_upper, zero_index), + XlaHelpers::IntegerLiteral(builder, index_type, n), + num_upper); - auto indicator = builder->And(builder->Le(builder->Neg(num_lower), offset), - builder->Le(offset, num_upper)); - indicator = builder->Broadcast(indicator, batch_shape.dim_sizes()); + auto indicator = xla::And(xla::Le(xla::Neg(num_lower), offset), + xla::Le(offset, num_upper)); + indicator = xla::Broadcast(indicator, batch_shape.dim_sizes()); auto zero_input = XlaHelpers::Zero(builder, input_type); - auto output = builder->Select( - indicator, input, - builder->Broadcast(zero_input, input_shape.dim_sizes())); + auto output = xla::Select( + indicator, input, xla::Broadcast(zero_input, input_shape.dim_sizes())); context->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc index db53f6fef8d6bf901c8281f50791ca6766c46efd..e2ab4b83cfb45b2f9a7f3aba2d2a927d10ad8b85 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { @@ -61,14 +63,11 @@ class MatrixSetDiagOp : public XlaOpKernel { auto zero = XlaHelpers::Zero(builder, context->input_type(0)); // Create an indicator tensor that is true only on the diagonal. - xla::XlaOp iota_m; - OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m)); - xla::XlaOp iota_n; - OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n)); - auto indicator = builder->Eq(iota_m, - builder->Broadcast(iota_n, {m}), - /*broadcast_dimensions=*/{0}); - indicator = builder->Broadcast(indicator, batch_shape.dim_sizes()); + xla::XlaOp iota_m = xla::Iota(builder, xla::S32, m); + xla::XlaOp iota_n = xla::Iota(builder, xla::S32, n); + auto indicator = xla::Eq(iota_m, xla::Broadcast(iota_n, {m}), + /*broadcast_dimensions=*/{0}); + indicator = xla::Broadcast(indicator, batch_shape.dim_sizes()); // Broadcast diag up to the input shape. Use an implicit broadcast (Add) // because we need to broadcast on the right. @@ -77,10 +76,10 @@ class MatrixSetDiagOp : public XlaOpKernel { if (min_dim != m) { diag_broadcast_dims.back() = rank - 1; } - diag = builder->Add(diag, builder->Broadcast(zero, input_shape.dim_sizes()), - /*broadcast_dimensions=*/diag_broadcast_dims); + diag = xla::Add(diag, xla::Broadcast(zero, input_shape.dim_sizes()), + /*broadcast_dimensions=*/diag_broadcast_dims); - auto output = builder->Select(indicator, diag, input); + auto output = xla::Select(indicator, diag, input); context->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc index eaed93146460de5a6e8328432302cc75bf36a534..f4def11d08c31513aec5aad15187016a7294c2fd 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -30,13 +30,9 @@ class MatrixTriangularSolveOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { auto result = TriangularSolve( - ctx->builder(), ctx->Input(0), ctx->Input(1), /*left_side=*/true, + ctx->Input(0), ctx->Input(1), /*left_side=*/true, /*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_); - if (!result.ok()) { - ctx->SetStatus(result.status()); - return; - } - ctx->SetOutput(0, result.ValueOrDie()); + ctx->SetOutput(0, result); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index 7e9de3ef9b245c113cc143128fe58e7e017a361c..529959dbd90b05f8860360f70e087ef225150600 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/util/mirror_pad_mode.h" namespace tensorflow { @@ -27,21 +28,21 @@ class MirrorPadOp : public XlaOpKernel { xla::StatusOr DoMirrorPad(const xla::XlaOp& t, const xla::Shape& original_shape, - const xla::Literal& pad_literal, + const xla::LiteralSlice& pad_literal, xla::XlaBuilder* b) { xla::XlaOp accum = t; for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0; --dimno) { - auto t_rev = b->Rev(accum, {dimno}); + auto t_rev = xla::Rev(accum, {dimno}); TF_ASSIGN_OR_RETURN(int64 lhs_padding, pad_literal.GetIntegralAsS64({dimno, 0})); TF_ASSIGN_OR_RETURN(int64 rhs_padding, pad_literal.GetIntegralAsS64({dimno, 1})); int64 dim_size = original_shape.dimensions(dimno); - auto lhs_pad = b->SliceInDim(t_rev, dim_size - 1 - lhs_padding, - dim_size - 1, 1, dimno); - auto rhs_pad = b->SliceInDim(t_rev, 1, 1 + rhs_padding, 1, dimno); - accum = b->ConcatInDim({lhs_pad, accum, rhs_pad}, dimno); + auto lhs_pad = xla::SliceInDim(t_rev, dim_size - 1 - lhs_padding, + dim_size - 1, 1, dimno); + auto rhs_pad = xla::SliceInDim(t_rev, 1, 1 + rhs_padding, 1, dimno); + accum = xla::ConcatInDim(b, {lhs_pad, accum, rhs_pad}, dimno); } return accum; } diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc index aecaabb6dcf46bdd6ae3da929448d6370acb989b..3aed47de2603f3e187ad515d4db3f884da4c6cc8 100644 --- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -76,11 +77,10 @@ class PackOp : public XlaOpKernel { for (int i = 0; i < num; ++i) { // Reshape the inputs to have an extra dimension of size 1. - reshaped_inputs[i] = - ctx->builder()->Reshape(values[i], child_shape.dim_sizes()); + reshaped_inputs[i] = xla::Reshape(values[i], child_shape.dim_sizes()); } - ctx->SetOutput(0, ctx->builder()->ConcatInDim(reshaped_inputs, axis)); + ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), reshaped_inputs, axis)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index 7c95475e7b1f02183e44f73f116a4aeb25f05c09..89fd610bc63349d008836c3c4e6ec8927c232a54 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" @@ -63,8 +64,8 @@ class PadOp : public XlaOpKernel { int before = pad_literal.Get({i, 0}); int after = pad_literal.Get({i, 1}); OP_REQUIRES(ctx, before >= 0 && after >= 0, - errors::InvalidArgument("Paddings must be non-negative: ", - before, " ", after)); + errors::InvalidArgument( + "Paddings must be non-negative: ", before, " ", after)); dim->set_edge_padding_low(before); dim->set_edge_padding_high(after); } @@ -74,11 +75,10 @@ class PadOp : public XlaOpKernel { if (ctx->num_inputs() == 3) { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(2)), errors::InvalidArgument("constant_values must be a scalar.")); - ctx->SetOutput(0, - ctx->builder()->Pad(ctx->Input(0), ctx->Input(2), config)); + ctx->SetOutput(0, xla::Pad(ctx->Input(0), ctx->Input(2), config)); } else { auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); - ctx->SetOutput(0, ctx->builder()->Pad(ctx->Input(0), zero, config)); + ctx->SetOutput(0, xla::Pad(ctx->Input(0), zero, config)); } } }; diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index f8e7b48a0fd94835964aea033ad33523150067b4..771dcbab21691ff1f018e4d65815cd5a53c9447a 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" @@ -113,8 +114,8 @@ class PoolingOp : public XlaOpKernel { xla::XlaBuilder* const b = ctx->builder(); auto input = XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_); - auto reduce = ctx->builder()->ReduceWindow( - input, InitValue(b), *Reduction(ctx), ksize, stride, padding_); + auto reduce = xla::ReduceWindow(input, InitValue(b), *Reduction(ctx), ksize, + stride, padding_); auto pooled = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); ctx->SetOutput(0, PostProcessOutput(ctx, pooled, input_type(0), input_shape)); @@ -190,7 +191,7 @@ static xla::XlaOp AvgPoolDivideByCount( auto divisor = XlaHelpers::IntegerLiteral(ctx->builder(), dtype, window_size); - return ctx->builder()->Div(output, divisor); + return xla::Div(output, divisor); } else { // For SAME padding, the padding shouldn't be included in the // counts. We use another ReduceWindow to find the right counts. @@ -212,18 +213,18 @@ static xla::XlaOp AvgPoolDivideByCount( // Build a matrix of all 1s, with the same width/height as the input. const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); - auto ones = ctx->builder()->Broadcast( + auto ones = xla::Broadcast( XlaHelpers::One(ctx->builder(), accumulation_type), input_dim_sizes); // Perform a ReduceWindow with the same window size, strides, and padding // to count the number of contributions to each result element. - auto reduce = ctx->builder()->ReduceWindow( + auto reduce = xla::ReduceWindow( ones, XlaHelpers::Zero(ctx->builder(), accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), window_ksize, window_stride, xla::Padding::kSame); auto counts = XlaHelpers::ConvertElementType(ctx->builder(), reduce, dtype); - return ctx->builder()->Div(output, counts, window_dims); + return xla::Div(output, counts, window_dims); } } @@ -347,9 +348,9 @@ class MaxPoolGradOp : public XlaOpKernel { xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2)); auto select = CreateScalarGeComputation(element_type, ctx->builder()); auto scatter = CreateScalarAddComputation(element_type, ctx->builder()); - xla::XlaOp gradients = ctx->builder()->SelectAndScatter( - input, select, ksize_, stride_, xla_padding, out_backprop, init_value, - scatter); + xla::XlaOp gradients = + xla::SelectAndScatter(input, select, ksize_, stride_, xla_padding, + out_backprop, init_value, scatter); ctx->SetOutput(0, gradients); } @@ -485,12 +486,12 @@ class AvgPoolGradOp : public XlaOpKernel { } auto zero = XlaHelpers::Zero(b, dtype); - auto padded_gradients = b->Pad(out_backprop_div, zero, padding_config); + auto padded_gradients = xla::Pad(out_backprop_div, zero, padding_config); // in_backprop = padded_gradients ones std::vector ones(num_dims(), 1LL); auto accumulation_type = XlaHelpers::SumAccumulationType(dtype); - auto in_backprop = b->ReduceWindow( + auto in_backprop = xla::ReduceWindow( XlaHelpers::ConvertElementType(b, padded_gradients, accumulation_type), XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), ksize_, @@ -614,17 +615,18 @@ class MaxPoolGradGradOp : public XlaOpKernel { auto b = ctx->builder(); - auto sixteen = b->ConstantR0(16); + auto sixteen = xla::ConstantR0(b, 16); // in (f32) -> round to bf16 -> f32 for correct bitwidth -> 16-high-bit u32 - auto in_hi = b->BitcastConvertType( - b->ConvertElementType(b->ConvertElementType(input, xla::BF16), - xla::F32), + auto in_hi = xla::BitcastConvertType( + xla::ConvertElementType(xla::ConvertElementType(input, xla::BF16), + xla::F32), xla::U32); - auto bp_int = b->BitcastConvertType(out_backprop, xla::U32); - auto bp_hi = b->ShiftRightLogical(bp_int, sixteen); - auto bp_lo = b->ShiftRightLogical(b->ShiftLeft(bp_int, sixteen), sixteen); - auto in_hi_bp_hi = b->Add(in_hi, bp_hi); // Want an unsigned add. - auto in_hi_bp_lo = b->Add(in_hi, bp_lo); // Want an unsigned add. + auto bp_int = xla::BitcastConvertType(out_backprop, xla::U32); + auto bp_hi = xla::ShiftRightLogical(bp_int, sixteen); + auto bp_lo = + xla::ShiftRightLogical(xla::ShiftLeft(bp_int, sixteen), sixteen); + auto in_hi_bp_hi = xla::Add(in_hi, bp_hi); // Want an unsigned add. + auto in_hi_bp_lo = xla::Add(in_hi, bp_lo); // Want an unsigned add. auto init_value = XlaHelpers::MinValue(b, DT_FLOAT); // We will reduce by taking the maximal value up to 16 bits (ignoring the lo @@ -633,39 +635,41 @@ class MaxPoolGradGradOp : public XlaOpKernel { { // F32 parameters to satisfy lowering type restriction for reduce opcode. const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {}); - auto lhs = rb->Parameter(0, scalar, "lhs"); - auto rhs = rb->Parameter(1, scalar, "rhs"); - auto sixteen = rb->ConstantR0(16); - auto lhs_criteria = rb->ShiftLeft( - rb->ShiftRightLogical(rb->BitcastConvertType(lhs, xla::S32), sixteen), - sixteen); - auto rhs_criteria = rb->ShiftLeft( - rb->ShiftRightLogical(rb->BitcastConvertType(rhs, xla::S32), sixteen), - sixteen); + auto lhs = xla::Parameter(rb.get(), 0, scalar, "lhs"); + auto rhs = xla::Parameter(rb.get(), 1, scalar, "rhs"); + auto sixteen = xla::ConstantR0(rb.get(), 16); + auto lhs_criteria = + xla::ShiftLeft(xla::ShiftRightLogical( + xla::BitcastConvertType(lhs, xla::S32), sixteen), + sixteen); + auto rhs_criteria = + xla::ShiftLeft(xla::ShiftRightLogical( + xla::BitcastConvertType(rhs, xla::S32), sixteen), + sixteen); // Must use a F32 comparison, because S32 would not work for negatives. - rb->Select(rb->Ge(rb->BitcastConvertType(lhs_criteria, xla::F32), - rb->BitcastConvertType(rhs_criteria, xla::F32)), - lhs, rhs); + xla::Select(xla::Ge(xla::BitcastConvertType(lhs_criteria, xla::F32), + xla::BitcastConvertType(rhs_criteria, xla::F32)), + lhs, rhs); } auto reduce = rb->BuildAndNoteError(); xla::Padding xla_padding = (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; auto pooled_hi = - b->ReduceWindow(b->BitcastConvertType(in_hi_bp_hi, xla::F32), - init_value, reduce, ksize_, stride_, xla_padding); + xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_hi, xla::F32), + init_value, reduce, ksize_, stride_, xla_padding); auto pooled_lo = - b->ReduceWindow(b->BitcastConvertType(in_hi_bp_lo, xla::F32), - init_value, reduce, ksize_, stride_, xla_padding); + xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_lo, xla::F32), + init_value, reduce, ksize_, stride_, xla_padding); auto grads_hi = - b->ShiftLeft(b->BitcastConvertType(pooled_hi, xla::U32), sixteen); - auto grads_lo = b->ShiftRightLogical( - b->ShiftLeft(b->BitcastConvertType(pooled_lo, xla::U32), sixteen), + xla::ShiftLeft(xla::BitcastConvertType(pooled_hi, xla::U32), sixteen); + auto grads_lo = xla::ShiftRightLogical( + xla::ShiftLeft(xla::BitcastConvertType(pooled_lo, xla::U32), sixteen), sixteen); - auto grads = b->Add(grads_hi, grads_lo); // Want an unsigned add. + auto grads = xla::Add(grads_hi, grads_lo); // Want an unsigned add. xla::PrimitiveType element_type; OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type)); - ctx->SetOutput(0, b->BitcastConvertType(grads, element_type)); + ctx->SetOutput(0, xla::BitcastConvertType(grads, element_type)); } protected: @@ -694,5 +698,18 @@ REGISTER_XLA_OP(Name("MaxPoolGradGradV2") .CompileTimeConstInput("strides"), MaxPool2DGradGradOp); +class MaxPool3DGradGradOp : public MaxPoolGradGradOp { + public: + explicit MaxPool3DGradGradOp(OpKernelConstruction* ctx) + : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/3) { + string data_format; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); + OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + } +}; +REGISTER_XLA_OP(Name("MaxPool3DGradGrad").TypeConstraint("T", DT_FLOAT), + MaxPool3DGradGradOp); + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index 661cd5923e1023eaf89a6bc4f56fcc362c8bcfb6..02293796e47063b81a9ff46c8b911461e3a5f5e5 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -28,82 +30,115 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_)); - OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63), - errors::InvalidArgument("num_bits is out of range: ", num_bits_, - " with signed_input_ ", signed_input_)); } void Compile(XlaOpKernelContext* ctx) override { xla::XlaOp input = ctx->Input(0); const DataType data_type = ctx->input_type(0); - // Comments taken from semantics description at - // https://www.tensorflow.org/versions/r1.0/api_docs/cc/class/tensorflow/ops/quantize-and-dequantize - // - // ... we find m such that - // - // m = max(abs(input_min), abs(input_max)) if range_given is true, - // m = max(abs(min_elem(input)), - // abs(max_elem(input))) otherwise. + xla::PrimitiveType xla_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(data_type, &xla_type)); + xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp input_min, input_max; + + // The implementation follows + // tensorflow/core/kernels/quantize_and_dequantize_op.h closely. + xla::XlaOp min_range, max_range; if (range_given_) { - double input_min_value, input_max_value; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(1, &input_min_value)); - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(2, &input_max_value)); - input_min = XlaHelpers::FloatLiteral(b, data_type, input_min_value); - input_max = XlaHelpers::FloatLiteral(b, data_type, input_max_value); + min_range = ctx->Input(1); + max_range = ctx->Input(2); } else { const xla::XlaComputation* fmax = ctx->GetOrCreateMax(data_type); const xla::XlaComputation* fmin = ctx->GetOrCreateMin(data_type); - input_min = - b->ReduceAll(input, XlaHelpers::MaxValue(b, data_type), *fmin); - input_max = - b->ReduceAll(input, XlaHelpers::MinValue(b, data_type), *fmax); + min_range = ReduceAll(input, XlaHelpers::MaxValue(b, data_type), *fmin); + max_range = ReduceAll(input, XlaHelpers::MinValue(b, data_type), *fmax); } - xla::XlaOp m = b->Max(b->Abs(input_min), b->Abs(input_max)); - - // Next, we choose our fixed-point quantization buckets, [min_fixed, - // max_fixed]. If signed_input is true, this is - // - // [min_fixed, max_fixed ] = [-((1 << (num_bits - 1)) - 1), - // (1 << (num_bits - 1)) - 1]. - // - // Otherwise, if signed_input is false, the fixed-point range is - // - // [min_fixed, max_fixed] = [0, (1 << num_bits) - 1]. - int64 min_fixed, max_fixed; + + xla::XlaOp num_bits; + if (num_bits_ < 0) { + OP_REQUIRES( + ctx, ctx->num_inputs() == 4, + errors::Internal("Expected 4 inputs to QuantizeAndDequantize")); + num_bits = ctx->Input(3); + } else { + num_bits = xla::ConstantR0(b, num_bits_); + } + + const xla::XlaOp zero = XlaHelpers::Zero(b, data_type); + const xla::XlaOp one = XlaHelpers::One(b, data_type); + const xla::XlaOp two = XlaHelpers::FloatLiteral(b, data_type, 2.0); + const xla::XlaOp half = XlaHelpers::FloatLiteral(b, data_type, 0.5); + + // Calculate the range for the simulated integer quantization: + // e.g. [-128,127] for signed = true, num_bits = 8, + // or [0, 255] for signed = false, num_bits = 8. + // We do this in floating point for hardware that does not have 64-bit + // integer support. + xla::XlaOp min_quantized, max_quantized; if (signed_input_) { - min_fixed = -((1LL << (num_bits_ - 1)) - 1); - max_fixed = (1LL << (num_bits_ - 1)) - 1; + min_quantized = + -Pow(two, ConvertElementType(num_bits - xla::ConstantR0(b, 1), + xla_type)); + max_quantized = + Pow(two, ConvertElementType(num_bits - xla::ConstantR0(b, 1), + xla_type)) - + one; } else { - min_fixed = 0; - max_fixed = (1LL << num_bits_) - 1; + min_quantized = zero; + max_quantized = Pow(two, ConvertElementType(num_bits, xla_type)) - one; } - // From this we compute our scaling factor, s: - // - // s = (max_fixed - min_fixed) / (2 * m). - xla::XlaOp s = - b->Div(XlaHelpers::FloatLiteral(b, data_type, max_fixed - min_fixed), - b->Mul(XlaHelpers::FloatLiteral(b, data_type, 2.0), m)); + // Determine the maximum scaling factor that would scale + // [min_range, max_range] to not exceed [min_quantized, max_quantized], + // while keeping 0 unchanged. + xla::XlaOp scale_from_min_side = + Select(Gt(min_quantized * min_range, zero), min_quantized / min_range, + XlaHelpers::MaxFiniteValue(b, data_type)); + xla::XlaOp scale_from_max_side = + Select(Gt(max_quantized * max_range, zero), max_quantized / max_range, + XlaHelpers::MaxFiniteValue(b, data_type)); - // Now we can quantize and dequantize the elements of our tensor. An element - // e is transformed into e': - // - // e' = (e * s).round_to_nearest() / s. - xla::XlaOp result = b->Div(b->Round(b->Mul(input, s)), s); + // Note: Avoids changing the side of the range that determines scale. + xla::XlaOp cond = Lt(scale_from_min_side, scale_from_max_side); + xla::XlaOp scale = Select(cond, scale_from_min_side, scale_from_max_side); + xla::XlaOp inverse_scale = + Select(cond, min_range / min_quantized, max_range / max_quantized); + min_range = Select(cond, min_range, min_quantized * inverse_scale); + max_range = Select(cond, max_quantized * inverse_scale, max_range); + if (range_given_) { + // Note: The clamping here is to avoid overflow in the quantized type. + // The semantics of the op does not guarantee to clamp to the specified + // min_range and max_range - because we may have changed either min_range + // or max_range. + // No need to clamp to min_range and max_range if range_given_ == false as + // in that case they were measured from the tensor. + input = Clamp(min_range, input, max_range); + } + xla::XlaOp result = + Floor((input - min_range) * scale + half) * inverse_scale + min_range; ctx->SetOutput(0, result); } - int64 num_bits_; + protected: + int64 num_bits_ = -1; bool signed_input_; bool range_given_; }; -REGISTER_XLA_OP(Name("QuantizeAndDequantizeV2"), QuantizeAndDequantizeOp); +class QuantizeAndDequantizeV2Op : public QuantizeAndDequantizeOp { + public: + explicit QuantizeAndDequantizeV2Op(OpKernelConstruction* ctx) + : QuantizeAndDequantizeOp(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_)); + OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63), + errors::InvalidArgument("num_bits is out of range: ", num_bits_, + " with signed_input_ ", signed_input_)); + } +}; + +REGISTER_XLA_OP(Name("QuantizeAndDequantizeV2"), QuantizeAndDequantizeV2Op); +REGISTER_XLA_OP(Name("QuantizeAndDequantizeV3"), QuantizeAndDequantizeOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 105be38fe26b6667e8b4ce6da92a3969cdc0c187..d5b645d70a68fafea4cb77c10dacf88f2e55ce8e 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -18,6 +18,7 @@ limitations under the License. // TODO(misard,phawkins): add tests. #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" +#include "tensorflow/compiler/tf2xla/lib/random.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -25,6 +26,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -45,8 +48,8 @@ class RandomUniformOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp result = b->RngUniform(XlaHelpers::Zero(b, dtype), - XlaHelpers::One(b, dtype), xla_shape); + xla::XlaOp result = xla::RngUniform(XlaHelpers::Zero(b, dtype), + XlaHelpers::One(b, dtype), xla_shape); ctx->SetOutput(0, result); } @@ -78,12 +81,11 @@ class RandomShuffleOp : public XlaOpKernel { // Generate the random swaps for the indices. auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n}); auto swaps = - builder->RngUniform(builder->ConstantR0(0), - builder->ConstantR0(n), swaps_shape); + xla::RngUniform(xla::ConstantR0(builder, 0), + xla::ConstantR0(builder, n), swaps_shape); // Generate range(n) as the initial value for the indices to be swapped. - xla::XlaOp indices; - TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, n, &indices)); + xla::XlaOp indices = xla::Iota(builder, xla::S32, n); // Swap the indices at i and swaps[i]. auto swap_body_fn = [&](xla::XlaOp i, @@ -92,17 +94,17 @@ class RandomShuffleOp : public XlaOpKernel { -> xla::StatusOr> { auto swaps = loop_vars[0]; auto indices = loop_vars[1]; - i = builder->Reshape(i, {1}); + i = xla::Reshape(i, {1}); // temp = indices[i] - auto temp = builder->DynamicSlice(indices, i, {1}); + auto temp = xla::DynamicSlice(indices, i, {1}); // swap_index = swaps[i] - auto swap_index = builder->DynamicSlice(swaps, i, {1}); + auto swap_index = xla::DynamicSlice(swaps, i, {1}); // swap_value = indices[swaps[i]] - auto swap_value = builder->DynamicSlice(indices, swap_index, {1}); + auto swap_value = xla::DynamicSlice(indices, swap_index, {1}); // indices[i] = indices[swaps[i]] - indices = builder->DynamicUpdateSlice(indices, swap_value, i); + indices = xla::DynamicUpdateSlice(indices, swap_value, i); // indices[swaps[i]] = temp - indices = builder->DynamicUpdateSlice(indices, temp, swap_index); + indices = xla::DynamicUpdateSlice(indices, temp, swap_index); return std::vector{swaps, indices}; }; // for i in range(n): @@ -152,7 +154,7 @@ class RandomUniformIntOp : public XlaOpKernel { auto minval = ctx->Input(1); auto maxval = ctx->Input(2); - ctx->SetOutput(0, ctx->builder()->RngUniform(minval, maxval, xla_shape)); + ctx->SetOutput(0, xla::RngUniform(minval, maxval, xla_shape)); } private: @@ -178,8 +180,8 @@ class RandomStandardNormalOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); // Normal distribution with a mean of 0 and a standard deviation of 1: - xla::XlaOp result = b->RngNormal(XlaHelpers::Zero(b, dtype), - XlaHelpers::One(b, dtype), xla_shape); + xla::XlaOp result = xla::RngNormal(XlaHelpers::Zero(b, dtype), + XlaHelpers::One(b, dtype), xla_shape); ctx->SetOutput(0, result); } @@ -205,58 +207,17 @@ class TruncatedNormalOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); - auto two_sd = [dtype](bool negate, xla::XlaBuilder* b) { - return XlaHelpers::FloatLiteral(b, dtype, negate ? -2.0 : 2.0); - }; - auto out_of_range_mask = [two_sd](xla::XlaOp candidate, - xla::XlaBuilder* b) { - xla::XlaOp too_large = b->Gt(candidate, two_sd(false, b)); - xla::XlaOp too_small = b->Lt(candidate, two_sd(true, b)); - return b->Or(too_large, too_small); - }; - - // The algorithm we're using is roughly: - // - // while (any(candidate < mean-2*sd || candidate > mean+2*sd)) { - // out_of_range_mask := candidate < mean-2*sd || candidate > mean+2*sd - // candidate = select(out_of_range_mask, rng_normal(), candidate) - // } - std::vector initial_values = { - // The current candidate. - b->Broadcast(XlaHelpers::Zero(b, dtype), shape.dim_sizes()), - // The to_resample mask, where 'true' identifies a location in the - // current candidate that is out of range and must be regenerated. - b->Broadcast(b->ConstantR0(true), shape.dim_sizes()), - // Is any element in the mask true? - b->ConstantR0(true)}; - auto condition = [&](gtl::ArraySlice values, - xla::XlaBuilder* b) -> xla::StatusOr { - // Continue while any element in the mask is true. - return values[2]; - }; - auto body = - [&](gtl::ArraySlice values, - xla::XlaBuilder* b) -> xla::StatusOr> { - xla::XlaOp candidate = values[0]; - xla::XlaOp to_resample = values[1]; - xla::XlaOp mean = XlaHelpers::Zero(b, dtype); - xla::XlaOp stddev = XlaHelpers::One(b, dtype); - candidate = b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), - candidate); - // Compute a new to_resample mask, and determine whether any value is - // still out of range. - to_resample = out_of_range_mask(candidate, b); - TF_ASSIGN_OR_RETURN(xla::XlaOp done, Any(to_resample, b)); - return std::vector{candidate, to_resample, done}; - }; - auto result = - XlaWhileLoop(condition, body, initial_values, "truncated_normal", b); - OP_REQUIRES_OK(ctx, result.status()); - ctx->SetOutput(0, result.ValueOrDie()[0]); + xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype, 1.0); + xla::XlaOp min_positive = + XlaHelpers::FloatLiteral(b, dtype, std::numeric_limits::min()); + auto uniform = xla::RngUniform(min_positive, one, xla_shape); + ctx->SetOutput(0, TruncatedNormal(dtype, uniform)); } }; -REGISTER_XLA_OP(Name("TruncatedNormal").CompileTimeConstInput("shape"), +REGISTER_XLA_OP(Name("TruncatedNormal") + .CompileTimeConstInput("shape") + .TypeConstraint("dtype", DT_FLOAT), TruncatedNormalOp); } // anonymous namespace diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc index 08894489ac77bbbe4ddb067c06a6d031a537697d..76bd1e62aa1efd85d6ed489b9a6d22a2bacf2a8b 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" @@ -98,10 +99,10 @@ class ReduceWindowOp : public XlaOpKernel { { std::unique_ptr cb = builder->CreateSubBuilder("wrapper"); - auto x = cb->Parameter(0, scalar_shape, "x"); - auto y = cb->Parameter(1, scalar_shape, "y"); - auto outputs = cb->Call(*reducer.computation, {x, y}); - cb->GetTupleElement(outputs, 0); + auto x = xla::Parameter(cb.get(), 0, scalar_shape, "x"); + auto y = xla::Parameter(cb.get(), 1, scalar_shape, "y"); + auto outputs = xla::Call(cb.get(), *reducer.computation, {x, y}); + xla::GetTupleElement(outputs, 0); xla::StatusOr result = cb->Build(); OP_REQUIRES_OK(context, result.status()); wrapper = std::move(result.ValueOrDie()); @@ -112,7 +113,7 @@ class ReduceWindowOp : public XlaOpKernel { padding[i] = {padding_low_[i], padding_high_[i]}; } - xla::XlaOp output = builder->ReduceWindowWithGeneralPadding( + xla::XlaOp output = xla::ReduceWindowWithGeneralPadding( context->Input(0), context->Input(1), wrapper, window_dimensions_, window_strides_, padding); context->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 0f425637795e9633a8e36f921000ee2f5e25813a..d3573bac3d7641128fbfc2122336a7c4347836c0 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -35,7 +36,7 @@ class SumOp : public XlaReductionOp { } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Add(scalar_lhs, scalar_rhs); + xla::Add(scalar_lhs, scalar_rhs); } }; @@ -53,7 +54,7 @@ class ProdOp : public XlaReductionOp { void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Mul(scalar_lhs, scalar_rhs); + xla::Mul(scalar_lhs, scalar_rhs); } }; @@ -71,7 +72,7 @@ class MinOp : public XlaReductionOp { void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Min(scalar_lhs, scalar_rhs); + xla::Min(scalar_lhs, scalar_rhs); } }; @@ -88,7 +89,7 @@ class MaxOp : public XlaReductionOp { void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Max(scalar_lhs, scalar_rhs); + xla::Max(scalar_lhs, scalar_rhs); } }; @@ -105,7 +106,7 @@ class MeanOp : public XlaReductionOp { } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Add(scalar_lhs, scalar_rhs); + xla::Add(scalar_lhs, scalar_rhs); } xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder, @@ -113,7 +114,7 @@ class MeanOp : public XlaReductionOp { int64 num_elements_reduced) override { auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0), num_elements_reduced); - return builder->Div(reduce_output, divisor); + return xla::Div(reduce_output, divisor); } }; @@ -126,12 +127,12 @@ class AllOp : public XlaReductionOp { : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return builder->ConstantR0(true); + return xla::ConstantR0(builder, true); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->And(scalar_lhs, scalar_rhs); + xla::And(scalar_lhs, scalar_rhs); } }; @@ -143,12 +144,12 @@ class AnyOp : public XlaReductionOp { : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return builder->ConstantR0(false); + return xla::ConstantR0(builder, false); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Or(scalar_lhs, scalar_rhs); + xla::Or(scalar_lhs, scalar_rhs); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 4fd5bfd03999a7f8b7bb081cc4b03aa1434d4c3d..14506d65c4db4cea5bd9fc037536a894aea4330e 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -56,9 +57,9 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { // Evaluate the constant, reshaping to a 1-vector if it is a scalar. xla::Literal axes_literal; - OP_REQUIRES_OK(ctx, - ctx->ConstantInputReshaped( - 1, {axes_tensor_shape.num_elements()}, &axes_literal)); + OP_REQUIRES_OK( + ctx, ctx->ConstantInputReshaped(1, {axes_tensor_shape.num_elements()}, + &axes_literal)); VLOG(1) << "data shape: " << data_shape.DebugString(); VLOG(1) << "axes : " << axes_literal.ToString(); @@ -101,20 +102,20 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type)); - auto data = b->ConvertElementType(ctx->Input(0), type); + auto data = xla::ConvertElementType(ctx->Input(0), type); // Call virtual method to get the initial value. - auto initial = b->ConvertElementType(InitialValue(b), type); + auto initial = xla::ConvertElementType(InitialValue(b), type); // Make two scalar parameters of the desired type for the lambda. - auto rx = r.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x"); - auto ry = r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y"); + auto rx = xla::Parameter(&r, 0, xla::ShapeUtil::MakeShape(type, {}), "x"); + auto ry = xla::Parameter(&r, 1, xla::ShapeUtil::MakeShape(type, {}), "y"); // Call virtual method to build the reduction lambda. BuildReducer(&r, rx, ry); xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie(); - auto reduce = b->Reduce(data, initial, reduction_computation, xla_axes); + auto reduce = xla::Reduce(data, initial, reduction_computation, xla_axes); auto deconverted = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); auto finalized = BuildFinalizer(b, deconverted, num_elements_reduced); - auto result = keep_dims_ ? b->Reshape(finalized, final_shape) : finalized; + auto result = keep_dims_ ? xla::Reshape(finalized, final_shape) : finalized; ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index ba7d484d53d7258edaa5bc42fa116cf16e94835b..a4ba6c748a73f161ea252e2adf4050eb5dda7df5 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -34,7 +34,7 @@ class ReluOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* builder = ctx->builder(); auto zero = XlaHelpers::Zero(builder, input_type(0)); - ctx->SetOutput(0, builder->Max(zero, ctx->Input(0))); + ctx->SetOutput(0, xla::Max(zero, ctx->Input(0))); } }; @@ -46,7 +46,7 @@ class Relu6Op : public XlaOpKernel { xla::XlaBuilder* builder = ctx->builder(); auto zero = XlaHelpers::Zero(builder, input_type(0)); auto six = XlaHelpers::IntegerLiteral(builder, input_type(0), 6); - ctx->SetOutput(0, builder->Clamp(zero, ctx->Input(0), six)); + ctx->SetOutput(0, xla::Clamp(zero, ctx->Input(0), six)); } }; @@ -59,9 +59,9 @@ class ReluGradOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); const TensorShape shape = ctx->InputShape(0); const auto zero = - b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); - const auto pred = b->Gt(ctx->Input(1), zero); - ctx->SetOutput(0, b->Select(pred, ctx->Input(0), zero)); + xla::Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); + const auto pred = xla::Gt(ctx->Input(1), zero); + ctx->SetOutput(0, xla::Select(pred, ctx->Input(0), zero)); } }; @@ -74,12 +74,12 @@ class Relu6GradOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); const TensorShape shape = ctx->InputShape(0); const auto zero = - b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); - const auto six = b->Broadcast( + xla::Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); + const auto six = xla::Broadcast( XlaHelpers::IntegerLiteral(b, input_type(0), 6), shape.dim_sizes()); - auto out = - b->Select(b->And(b->Lt(ctx->Input(1), six), b->Gt(ctx->Input(1), zero)), - ctx->Input(0), zero); + auto out = xla::Select( + xla::And(xla::Lt(ctx->Input(1), six), xla::Gt(ctx->Input(1), zero)), + ctx->Input(0), zero); ctx->SetOutput(0, out); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index af4d64b159c09ed7e01017f25a2b23e58542dc3c..e0ca8dd8e27914ad60d0b97e8ac5f0b91a4fd9a6 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -90,8 +91,7 @@ class ReshapeOp : public XlaOpKernel { VLOG(1) << "Reshape " << input_shape.DebugString() << " " << shape.DebugString(); - ctx->SetOutput(0, - ctx->builder()->Reshape(ctx->Input(0), shape.dim_sizes())); + ctx->SetOutput(0, xla::Reshape(ctx->Input(0), shape.dim_sizes())); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index a711278638444be01fb865561957702368b75114..db7ea775e23e86bdbd9259e73dfa2412ef10ac6c 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -69,8 +69,7 @@ class RetvalOp : public XlaOpKernel { xla::XlaOp output = input; if (tc.is_entry_computation()) { - output = - ctx->builder()->Reshape(input, representation_shape.dim_sizes()); + output = xla::Reshape(input, representation_shape.dim_sizes()); } else { // The core from which a return value is returned depends on the // device assignment of the input to the retval. Since we can't change @@ -78,8 +77,8 @@ class RetvalOp : public XlaOpKernel { // introduce an operator here, even if the shape does not change. // TODO(b/76097077): propagate device assignments onto arguments and // return values of functions, and then reshape unconditionally. - output = ctx->builder()->GetTupleElement( - ctx->builder()->Tuple({output}), 0); + output = + xla::GetTupleElement(xla::Tuple(ctx->builder(), {output}), 0); } tc.AddRetval(index_, dtype_, shape, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index 2872a3c4d49d0d269aa3d216887a5c32cd51f1c3..037c422258555289711b8754f2277d077d0cd6a7 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -62,7 +63,7 @@ class ReverseOp : public XlaOpKernel { } } - ctx->SetOutput(0, ctx->builder()->Rev(ctx->Input(0), dimensions)); + ctx->SetOutput(0, xla::Rev(ctx->Input(0), dimensions)); } }; @@ -100,7 +101,7 @@ class ReverseV2Op : public XlaOpKernel { x_shape.dims(), ").")); } - ctx->SetOutput(0, ctx->builder()->Rev(ctx->Input(0), axes)); + ctx->SetOutput(0, xla::Rev(ctx->Input(0), axes)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 5d1c05268493f4f6404c40a4092a71f1e5b3f3b9..c810456f94322acfccae18d78efa861eede4648c 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -17,6 +17,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -85,103 +87,96 @@ class ReverseSequenceOp : public XlaOpKernel { auto condition_builder = builder->CreateSubBuilder("reverse_sequence_condition"); { - auto param = condition_builder->Parameter(0, tuple_shape, "param"); - auto i = condition_builder->GetTupleElement(param, 0); - condition_builder->Lt( - i, XlaHelpers::IntegerLiteral(condition_builder.get(), seq_lens_type, - batch_size)); + auto param = + xla::Parameter(condition_builder.get(), 0, tuple_shape, "param"); + auto i = xla::GetTupleElement(param, 0); + xla::Lt(i, XlaHelpers::IntegerLiteral(condition_builder.get(), + seq_lens_type, batch_size)); } auto condition = condition_builder->Build(); OP_REQUIRES_OK(context, condition.status()); auto body_builder = builder->CreateSubBuilder("reverse_sequence_body"); { - auto param = body_builder->Parameter(0, tuple_shape, "param"); - auto i = body_builder->GetTupleElement(param, 0); - auto seq_lens = body_builder->GetTupleElement(param, 1); - auto output = body_builder->GetTupleElement(param, 2); + auto param = xla::Parameter(body_builder.get(), 0, tuple_shape, "param"); + auto i = xla::GetTupleElement(param, 0); + auto seq_lens = xla::GetTupleElement(param, 1); + auto output = xla::GetTupleElement(param, 2); // seq_len is the sequence length of the current batch element (rank 1) - auto seq_len = body_builder->DynamicSlice( - seq_lens, body_builder->Reshape(i, {1}), {1}); + auto seq_len = xla::DynamicSlice(seq_lens, xla::Reshape(i, {1}), {1}); // Indices is the offset of the batch element in the input. - auto batch_element_indices = body_builder->Broadcast( - XlaHelpers::Zero(body_builder.get(), seq_lens_type), - {input_shape.dims()}); - batch_element_indices = body_builder->DynamicUpdateSlice( - batch_element_indices, body_builder->Reshape(i, {1}), - body_builder->Reshape( - XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, - batch_dim_), - {1})); + auto batch_element_indices = + xla::Broadcast(XlaHelpers::Zero(body_builder.get(), seq_lens_type), + {input_shape.dims()}); + batch_element_indices = xla::DynamicUpdateSlice( + batch_element_indices, xla::Reshape(i, {1}), + xla::Reshape(XlaHelpers::IntegerLiteral(body_builder.get(), + seq_lens_type, batch_dim_), + {1})); // Slice out the current batch element and pad it out in the sequence // dimension. TensorShape slice_shape = input_shape; slice_shape.set_dim(batch_dim_, 1); slice_shape.set_dim(seq_dim_, max_seq_len); - auto slice = body_builder->DynamicSlice(output, batch_element_indices, - slice_shape.dim_sizes()); + auto slice = xla::DynamicSlice(output, batch_element_indices, + slice_shape.dim_sizes()); auto padding_config = xla::MakeNoPaddingConfig(slice_shape.dims()); padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high( slice_shape.dim_size(seq_dim_)); - slice = body_builder->Pad( - slice, XlaHelpers::Zero(body_builder.get(), input_type), - padding_config); + slice = xla::Pad(slice, XlaHelpers::Zero(body_builder.get(), input_type), + padding_config); // Now slice out the reversed sequence from its actual start. // sequence_start_indices is the offset of the start of the reversed // sequence in the input. The slice will go into the padding, however, we // will mask off these elements and replace them with elements from the // original input so their values do not matter. - auto sequence_start_indices = body_builder->Broadcast( - XlaHelpers::Zero(body_builder.get(), seq_lens_type), - {slice_shape.dims()}); - sequence_start_indices = body_builder->DynamicUpdateSlice( + auto sequence_start_indices = + xla::Broadcast(XlaHelpers::Zero(body_builder.get(), seq_lens_type), + {slice_shape.dims()}); + sequence_start_indices = xla::DynamicUpdateSlice( sequence_start_indices, - body_builder->Sub(XlaHelpers::IntegerLiteral( - body_builder.get(), seq_lens_type, max_seq_len), - seq_len), - body_builder->Reshape( - XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, - seq_dim_), - {1})); - slice = body_builder->DynamicSlice(slice, sequence_start_indices, - slice_shape.dim_sizes()); + xla::Sub(XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, + max_seq_len), + seq_len), + xla::Reshape(XlaHelpers::IntegerLiteral(body_builder.get(), + seq_lens_type, seq_dim_), + {1})); + slice = xla::DynamicSlice(slice, sequence_start_indices, + slice_shape.dim_sizes()); // Shift the reversed sequence to the left. - output = body_builder->DynamicUpdateSlice(output, slice, - batch_element_indices); + output = xla::DynamicUpdateSlice(output, slice, batch_element_indices); - body_builder->Tuple( - {body_builder->Add( - i, XlaHelpers::One(body_builder.get(), seq_lens_type)), + xla::Tuple( + body_builder.get(), + {xla::Add(i, XlaHelpers::One(body_builder.get(), seq_lens_type)), seq_lens, output}); } auto body = body_builder->Build(); OP_REQUIRES_OK(context, body.status()); - auto loop_output = builder->While( + auto loop_output = xla::While( condition.ValueOrDie(), body.ValueOrDie(), - builder->Tuple({XlaHelpers::Zero(builder, seq_lens_type), seq_lens, - builder->Rev(input, {seq_dim_})})); - auto output = builder->GetTupleElement(loop_output, 2); + xla::Tuple(builder, {XlaHelpers::Zero(builder, seq_lens_type), seq_lens, + xla::Rev(input, {seq_dim_})})); + auto output = xla::GetTupleElement(loop_output, 2); // Mask out elements after the sequence length. - xla::XlaOp iota; - OP_REQUIRES_OK( - context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota)); + xla::XlaOp iota = + xla::Iota(builder, seq_lens_xla_shape.element_type(), max_seq_len); std::vector dims(input_shape.dims(), 1); dims[batch_dim_] = batch_size; - auto mask = builder->Lt(iota, builder->Reshape(seq_lens, dims), {seq_dim_}); + auto mask = xla::Lt(iota, xla::Reshape(seq_lens, dims), {seq_dim_}); // Broadcast the mask up to the input shape. - mask = - builder->Or(mask, builder->Broadcast(builder->ConstantR0(false), - input_shape.dim_sizes())); + mask = xla::Or(mask, xla::Broadcast(xla::ConstantR0(builder, false), + input_shape.dim_sizes())); - output = builder->Select(mask, output, input); + output = xla::Select(mask, output, input); context->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 1819fb543317eed15b2fe0518d74aba5c564697d..76924c6a01a44e7a723b8c8895e8decbdd466c79 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -100,7 +101,7 @@ class ScanOp : public XlaOpKernel { init = XlaHelpers::One(builder, dtype); reducer = ctx->GetOrCreateMul(dtype); } - auto output = builder->ReduceWindowWithGeneralPadding( + auto output = xla::ReduceWindowWithGeneralPadding( XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init, *reducer, window_dims, window_strides, padding); output = @@ -110,12 +111,12 @@ class ScanOp : public XlaOpKernel { // of all the input elements. Slice off this extra "last" element. if (exclusive_) { if (reverse_) { - output = builder->SliceInDim(output, 1, input_shape.dim_size(axis) + 1, - 1, axis); + output = + xla::SliceInDim(output, 1, input_shape.dim_size(axis) + 1, 1, axis); } else { output = - builder->SliceInDim(output, 0, input_shape.dim_size(axis), 1, axis); + xla::SliceInDim(output, 0, input_shape.dim_size(axis), 1, axis); } } ctx->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index f2c63b4f9083ad3c7dd7cf318dc22def1e99fa9f..14709bb6cbce4b3ae0f7ff859b0fa622c6eda293 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -103,8 +104,8 @@ class ScatterNdOp : public XlaOpKernel { updates_shape)); xla::XlaBuilder* builder = context->builder(); - auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype), - buffer_shape.dim_sizes()); + auto buffer = xla::Broadcast(XlaHelpers::Zero(builder, dtype), + buffer_shape.dim_sizes()); auto indices = context->Input(0); auto updates = context->Input(1); auto result = diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index 664078ca16c6d5d4b57c4a8c661ad0848f30dd7d..db7e55942012142297f6a4d6afa1065eb0bb24f6 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -22,12 +22,19 @@ limitations under the License. namespace tensorflow { namespace { -class UnsortedSegmentSum : public XlaOpKernel { +class UnsortedSegmentReduce : public XlaOpKernel { public: - explicit UnsortedSegmentSum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + explicit UnsortedSegmentReduce(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); } + // The initial value to initialize elements of the output to. + virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0; + + // A function to combine two scalars with the same index (e.g., sum). + virtual xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, + xla::XlaBuilder* builder) = 0; + void Compile(XlaOpKernelContext* ctx) override { // output = unsorted_segment_sum(data, indices, num_segments) // Compute a tensor such that: @@ -50,27 +57,29 @@ class UnsortedSegmentSum : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments)); OP_REQUIRES(ctx, data_shape.dims() >= indices_shape.dims(), - errors::InvalidArgument( - "UnsortedSegmentSum requires that indices' rank be" - " less than or equal to data's rank.")); + errors::InvalidArgument(type_string(), + " requires that indices' rank be" + " less than or equal to data's rank.")); // Validate that indices.shape is a prefix of data.shape. for (int d = 0; d < indices_shape.dims(); ++d) { - OP_REQUIRES(ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)), - errors::InvalidArgument( - "UnsortedSegmentSum requires indices shape to be prefix" - " of data_shape, but dimension ", - d, " differs ", data_shape.dim_size(d), " vs. ", - indices_shape.dim_size(d))); + OP_REQUIRES( + ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)), + errors::InvalidArgument(type_string(), + " requires indices shape to be prefix" + " of data_shape, but dimension ", + d, " differs ", data_shape.dim_size(d), + " vs. ", indices_shape.dim_size(d))); } xla::XlaBuilder* builder = ctx->builder(); TensorShape buffer_shape = data_shape; buffer_shape.RemoveDimRange(0, indices_shape.dims()); buffer_shape.InsertDim(0, num_segments); - auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype_), - buffer_shape.dim_sizes()); + auto buffer = + xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes()); - auto combiner = [](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) { - return builder->Add(a, b); + auto combiner = [this](xla::XlaOp a, xla::XlaOp b, + xla::XlaBuilder* builder) { + return Combine(a, b, builder); }; auto result = XlaScatter(buffer, /*updates=*/data, indices, @@ -79,13 +88,81 @@ class UnsortedSegmentSum : public XlaOpKernel { ctx->SetOutput(0, result.ValueOrDie()); } - private: + protected: DataType dtype_; }; +class UnsortedSegmentSum : public UnsortedSegmentReduce { + public: + explicit UnsortedSegmentSum(OpKernelConstruction* ctx) + : UnsortedSegmentReduce(ctx) {} + + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { + return XlaHelpers::Zero(builder, dtype_); + }; + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, + xla::XlaBuilder* builder) override { + return xla::Add(a, b); + }; +}; + REGISTER_XLA_OP( Name("UnsortedSegmentSum").CompileTimeConstInput("num_segments"), UnsortedSegmentSum); +class UnsortedSegmentProd : public UnsortedSegmentReduce { + public: + explicit UnsortedSegmentProd(OpKernelConstruction* ctx) + : UnsortedSegmentReduce(ctx) {} + + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { + return XlaHelpers::One(builder, dtype_); + }; + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, + xla::XlaBuilder* builder) override { + return xla::Mul(a, b); + }; +}; + +REGISTER_XLA_OP( + Name("UnsortedSegmentProd").CompileTimeConstInput("num_segments"), + UnsortedSegmentProd); + +class UnsortedSegmentMin : public UnsortedSegmentReduce { + public: + explicit UnsortedSegmentMin(OpKernelConstruction* ctx) + : UnsortedSegmentReduce(ctx) {} + + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { + return XlaHelpers::MaxFiniteValue(builder, dtype_); + }; + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, + xla::XlaBuilder* builder) override { + return xla::Min(a, b); + }; +}; + +REGISTER_XLA_OP( + Name("UnsortedSegmentMin").CompileTimeConstInput("num_segments"), + UnsortedSegmentMin); + +class UnsortedSegmentMax : public UnsortedSegmentReduce { + public: + explicit UnsortedSegmentMax(OpKernelConstruction* ctx) + : UnsortedSegmentReduce(ctx) {} + + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { + return XlaHelpers::MinFiniteValue(builder, dtype_); + }; + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, + xla::XlaBuilder* builder) override { + return xla::Max(a, b); + }; +}; + +REGISTER_XLA_OP( + Name("UnsortedSegmentMax").CompileTimeConstInput("num_segments"), + UnsortedSegmentMax); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index f9f48164d63492b057d4950abfc2ca6153e44870..5c010c9df23ba6c7732d87fa014879d93ff586ce 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -40,8 +41,6 @@ class SelectOp : public XlaOpKernel { "'then' and 'else' must have the same size. but received: ", then_shape.DebugString(), " vs. ", else_shape.DebugString())); - xla::XlaBuilder* builder = ctx->builder(); - auto cond_handle = ctx->Input(0); auto then_handle = ctx->Input(1); auto else_handle = ctx->Input(2); @@ -69,14 +68,14 @@ class SelectOp : public XlaOpKernel { const auto dim_sizes = then_shape.dim_sizes(); gtl::ArraySlice bdims = dim_sizes; bdims.pop_front(); - cond_handle = builder->Broadcast(cond_handle, bdims); + cond_handle = xla::Broadcast(cond_handle, bdims); std::vector dim_order(then_shape.dims()); dim_order[0] = then_shape.dims() - 1; std::iota(dim_order.begin() + 1, dim_order.end(), 0); - cond_handle = builder->Transpose(cond_handle, dim_order); + cond_handle = xla::Transpose(cond_handle, dim_order); } - ctx->SetOutput(0, builder->Select(cond_handle, then_handle, else_handle)); + ctx->SetOutput(0, xla::Select(cond_handle, then_handle, else_handle)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc index 9ce01d0d44509bbcbea18afdb4210a675834bb6d..6281d6c6533f7f49a269f5c7e52226ba0f1d29f6 100644 --- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc @@ -45,7 +45,7 @@ void SendOp::Compile(XlaOpKernelContext* ctx) { XlaCompiler* compiler = XlaContext::Get(ctx).compiler(); xla::ChannelHandle channel; OP_REQUIRES_OK(ctx, compiler->GetChannelHandle(tensor_name_, &channel)); - ctx->builder()->Send(ctx->Input(0), channel); + xla::Send(ctx->Input(0), channel); } REGISTER_XLA_OP(Name("XlaSend"), SendOp); @@ -76,7 +76,7 @@ void RecvOp::Compile(XlaOpKernelContext* ctx) { XlaCompiler* compiler = XlaContext::Get(ctx).compiler(); xla::ChannelHandle channel; OP_REQUIRES_OK(ctx, compiler->GetChannelHandle(tensor_name_, &channel)); - ctx->SetOutput(0, ctx->builder()->Recv(shape_, channel)); + ctx->SetOutput(0, xla::Recv(ctx->builder(), shape_, channel)); } REGISTER_XLA_OP(Name("XlaRecv"), RecvOp); diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 2c31f8d90891924f6f86a54ccf548de4df87f3bd..bc3d0bf5dfe9e5af8e50a25e27db7148e05e0cfd 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -55,9 +55,10 @@ Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) { // The type-specific part of the implementation of Range. template -Status CreateRangeTensor(const xla::Literal& start_literal, - const xla::Literal& limit_literal, - const xla::Literal& delta_literal, Tensor* output) { +Status CreateRangeTensor(const xla::LiteralSlice& start_literal, + const xla::LiteralSlice& limit_literal, + const xla::LiteralSlice& delta_literal, + Tensor* output) { T start = start_literal.Get({}); T limit = limit_literal.Get({}); T delta = delta_literal.Get({}); @@ -67,13 +68,13 @@ Status CreateRangeTensor(const xla::Literal& start_literal, } if (delta > 0) { if (start > limit) { - return errors::InvalidArgument("Requires start <= limit when delta > 0: ", - start, "/", limit); + return errors::InvalidArgument( + "Requires start <= limit when delta > 0: ", start, "/", limit); } } else { if (start < limit) { - return errors::InvalidArgument("Requires start >= limit when delta < 0: ", - start, "/", limit); + return errors::InvalidArgument( + "Requires start >= limit when delta < 0: ", start, "/", limit); } } int64 size = diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index d59720bef742c7441ee01a954247013559bb909c..5798823cd54c66dd179e3611c0041f7c5a1ff2b5 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -147,7 +148,7 @@ class ExpandDimsOp : public XlaOpKernel { dim = std::min(dim, existing_dims_size); new_shape.emplace(new_shape.begin() + dim, 1); - ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape)); + ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape)); } }; REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstInput("dim"), ExpandDimsOp); @@ -204,7 +205,7 @@ class SqueezeOp : public XlaOpKernel { } } - ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape)); + ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape)); } private: @@ -221,7 +222,7 @@ class ZerosLikeOp : public XlaOpKernel { const TensorShape input_shape = ctx->InputShape(0); auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); - ctx->SetOutput(0, ctx->builder()->Broadcast(zero, input_shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(zero, input_shape.dim_sizes())); } }; @@ -235,7 +236,7 @@ class OnesLikeOp : public XlaOpKernel { const TensorShape input_shape = ctx->InputShape(0); auto one = XlaHelpers::One(ctx->builder(), input_type(0)); - ctx->SetOutput(0, ctx->builder()->Broadcast(one, input_shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(one, input_shape.dim_sizes())); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index be1e97bf26fa4cde1b741c8d0b843a85ce33a59c..1864584adee357ce35a3e8a38a4e3c58c356bfca 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -92,8 +93,7 @@ class SliceOp : public XlaOpKernel { limits.push_back(begin[i] + size[i]); } std::vector strides(begin.size(), 1); - ctx->SetOutput( - 0, ctx->builder()->Slice(ctx->Input(0), begin, limits, strides)); + ctx->SetOutput(0, xla::Slice(ctx->Input(0), begin, limits, strides)); } else { // `begin` is not a compile-time constant. for (int i = 0; i < input_dims; ++i) { @@ -106,8 +106,7 @@ class SliceOp : public XlaOpKernel { input_shape.dim_size(i), "], but ", "got ", size[i])); } - ctx->SetOutput( - 0, ctx->builder()->DynamicSlice(ctx->Input(0), ctx->Input(1), size)); + ctx->SetOutput(0, xla::DynamicSlice(ctx->Input(0), ctx->Input(1), size)); } } }; diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index bbf5ee8b12186a582666121b1df5d8b7d881863e..d1c69f08b0bc85fc47c03015054dd18a65eeedec 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -47,25 +48,25 @@ class SoftmaxOp : public XlaOpKernel { const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type); // Find the max in each batch, resulting in a tensor of shape [batch] - auto logits_max = - b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim}); + auto logits_max = xla::Reduce(logits, XlaHelpers::MinValue(b, type), + max_func, {kClassDim}); // Subtract the max in batch b from every element in batch b. Broadcasts // along the batch dimension. - auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim}); - auto exp_shifted = b->Exp(shifted_logits); + auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim}); + auto exp_shifted = xla::Exp(shifted_logits); const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); auto converted = XlaHelpers::ConvertElementType(b, exp_shifted, accumulation_type); auto reduce = - b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); auto sum = XlaHelpers::ConvertElementType(b, reduce, type); auto softmax = log_ // softmax = shifted_logits - log(sum(exp(shifted_logits))) - ? b->Sub(shifted_logits, b->Log(sum), {kBatchDim}) + ? xla::Sub(shifted_logits, xla::Log(sum), {kBatchDim}) // softmax = exp(shifted_logits) / sum(exp(shifted_logits)) - : b->Div(exp_shifted, sum, {kBatchDim}); + : xla::Div(exp_shifted, sum, {kBatchDim}); ctx->SetOutput(0, softmax); } @@ -87,43 +88,44 @@ std::pair CrossEntropyWithLogits( xla::XlaBuilder* b = ctx->builder(); // Find the max in each batch, resulting in a tensor of shape [batch] auto logits_max = - b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim}); + xla::Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim}); // Subtract the max in batch b from every element in batch b. // Broadcasts along the batch dimension. - auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim}); + auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim}); // exp(logits - max_logits) - auto exp_shifted_logits = b->Exp(shifted_logits); + auto exp_shifted_logits = xla::Exp(shifted_logits); // sum_{class} (exp(logits - max_logits)) const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); auto converted = XlaHelpers::ConvertElementType(b, exp_shifted_logits, accumulation_type); - auto reduce = b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + auto reduce = + xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); auto sum_exp = XlaHelpers::ConvertElementType(b, reduce, type); // log(sum(exp(logits - max_logits))) - auto log_sum_exp = b->Log(sum_exp); + auto log_sum_exp = xla::Log(sum_exp); // sum(-labels * // ((logits - max_logits) - log(sum(exp(logits - max_logits))))) // along classes // (The subtraction broadcasts along the batch dimension.) - auto sub = b->Sub(shifted_logits, log_sum_exp, {kBatchDim}); - auto mul = b->Mul(b->Neg(labels), sub); + auto sub = xla::Sub(shifted_logits, log_sum_exp, {kBatchDim}); + auto mul = xla::Mul(xla::Neg(labels), sub); auto sum = - b->Reduce(XlaHelpers::ConvertElementType(b, mul, accumulation_type), - XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + xla::Reduce(XlaHelpers::ConvertElementType(b, mul, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); auto loss = XlaHelpers::ConvertElementType(b, sum, type); // backprop: prob - labels, where // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) // (where the division broadcasts along the batch dimension) xla::XlaOp backprop = - b->Sub(b->Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels); + xla::Sub(xla::Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels); return {loss, backprop}; } @@ -206,16 +208,14 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { // Builds a vector of {batch_size} that is 0 if the index is in range, or // NaN otherwise; then add that vector to the labels to force out-of-range // values to NaNs. - xla::XlaOp nan_or_zero = builder->Select( - builder->And( - builder->Le(XlaHelpers::Zero(builder, indices_type), indices), - builder->Lt(indices, XlaHelpers::IntegerLiteral( - builder, indices_type, depth))), - builder->Broadcast(XlaHelpers::Zero(builder, logits_type), - {batch_size}), - builder->Broadcast(XlaHelpers::FloatLiteral(builder, logits_type, NAN), - {batch_size})); - labels = builder->Add(labels, nan_or_zero, {0}); + xla::XlaOp nan_or_zero = xla::Select( + xla::And(xla::Le(XlaHelpers::Zero(builder, indices_type), indices), + xla::Lt(indices, XlaHelpers::IntegerLiteral( + builder, indices_type, depth))), + xla::Broadcast(XlaHelpers::Zero(builder, logits_type), {batch_size}), + xla::Broadcast(XlaHelpers::FloatLiteral(builder, logits_type, NAN), + {batch_size})); + labels = xla::Add(labels, nan_or_zero, {0}); xla::XlaOp loss, backprop; std::tie(loss, backprop) = diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..faaf8964ff7c40d75a493b03e6b400632117cb45 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" + +namespace tensorflow { +namespace { + +class XlaSortOp : public XlaOpKernel { + public: + explicit XlaSortOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + context->SetOutput(0, xla::Sort(context->Input(0))); + } +}; + +REGISTER_XLA_OP(Name("XlaSort"), XlaSortOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index ec077924b5b5af4a573c86c8d9aeb8623bd7f801..8a8525efa186ed4aa02c494f7505f6245677e96e 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { namespace { @@ -73,7 +74,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, "The product of the block dimensions must be positive")); xla::XlaOp padded = - b->Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config); + xla::Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config); // 2. Reshape `padded` to `reshaped_padded` of shape: // @@ -100,7 +101,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, std::copy(remainder_shape.begin(), remainder_shape.end(), reshaped_padded_shape.begin() + 1 + 2 * block_rank); - xla::XlaOp reshaped_padded = b->Reshape(padded, reshaped_padded_shape); + xla::XlaOp reshaped_padded = xla::Reshape(padded, reshaped_padded_shape); // 3. Permute dimensions of `reshaped_padded` to produce // `permuted_reshaped_padded` of shape: @@ -120,7 +121,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(), 1 + block_rank * 2); xla::XlaOp permuted_reshaped_padded = - b->Transpose(reshaped_padded, permutation); + xla::Transpose(reshaped_padded, permutation); // 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -140,7 +141,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, std::copy(remainder_shape.begin(), remainder_shape.end(), output_shape.begin() + 1 + block_rank); - xla::XlaOp output = b->Reshape(permuted_reshaped_padded, output_shape); + xla::XlaOp output = xla::Reshape(permuted_reshaped_padded, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 4c5886ee2a0f63d609f79fc690f457d93e284e3e..47d282fe9ec664bbc424793e93f778ebb13c6877 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { @@ -50,7 +51,6 @@ class SpaceToDepthOp : public XlaOpKernel { const gtl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); - xla::XlaBuilder* b = ctx->builder(); xla::XlaOp input = ctx->Input(0); int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); @@ -135,7 +135,7 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[1] / block_size_, block_size_, // input_shape[2] / block_size_, block_size_, // depth] - xla::XlaOp reshaped = b->Reshape(input, reshaped_shape); + xla::XlaOp reshaped = xla::Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -145,7 +145,7 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[2] / block_size_, // block_size_, block_size_, // depth] - xla::XlaOp permuted_reshaped = b->Transpose(reshaped, transpose_order); + xla::XlaOp permuted_reshaped = xla::Transpose(reshaped, transpose_order); // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -155,7 +155,7 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[2] / block_size_, // block_size_ * block_size_ * depth] // - xla::XlaOp output = b->Reshape(permuted_reshaped, output_shape); + xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..e831dc30a9d3c27ec3b1494e7d8a6de836ff2a11 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc @@ -0,0 +1,88 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/scatter.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +// Operator to convert sparse representations to dense. +class SparseToDenseOp : public XlaOpKernel { + public: + explicit SparseToDenseOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + // sparse_indices + const TensorShape indices_shape = context->InputShape(0); + OP_REQUIRES(context, indices_shape.dims() <= 2, + errors::InvalidArgument( + "sparse_indices should be a scalar, vector, or matrix, " + "got shape ", + indices_shape.DebugString())); + const int64 num_elems = + indices_shape.dims() > 0 ? indices_shape.dim_size(0) : 1; + const int64 num_dims = + indices_shape.dims() > 1 ? indices_shape.dim_size(1) : 1; + + // output_shape + TensorShape output_shape; + OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape)); + OP_REQUIRES(context, output_shape.dims() == num_dims, + errors::InvalidArgument( + "output_shape has incorrect number of elements: ", + output_shape.num_elements(), " should be: ", num_dims)); + + // sparse_values + const TensorShape sparse_values_shape = context->InputShape(2); + const int64 num_values = sparse_values_shape.num_elements(); + OP_REQUIRES( + context, + sparse_values_shape.dims() == 0 || + (sparse_values_shape.dims() == 1 && num_values == num_elems), + errors::InvalidArgument("sparse_values has incorrect shape ", + sparse_values_shape.DebugString(), + ", should be [] or [", num_elems, "]")); + + // default_value + const TensorShape default_value_shape = context->InputShape(3); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(default_value_shape), + errors::InvalidArgument("default_value should be a scalar.")); + + xla::XlaOp indices = context->Input(0); + xla::XlaOp sparse_values = context->Input(2); + xla::XlaOp default_value = context->Input(3); + + if (sparse_values_shape.dims() == 0 && num_elems != 1) { + sparse_values = Broadcast(sparse_values, {num_elems}); + } + xla::XlaBuilder* builder = context->builder(); + auto buffer = Broadcast(default_value, output_shape.dim_sizes()); + + auto result = XlaScatter(buffer, sparse_values, indices, + /*indices_are_vectors=*/num_dims > 1, + /*combiner=*/{}, builder); + context->SetOutput(0, builder->ReportErrorOrReturn(result)); + } +}; + +REGISTER_XLA_OP(Name("SparseToDense").CompileTimeConstInput("output_shape"), + SparseToDenseOp); + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 8958b2e7701e62d802e37a895c14b662ecf9786a..ca74cf24507e1666070751a17fb940a3ad594695 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -98,7 +99,7 @@ class SplitOp : public XlaOpKernel { // Slice out the ith split from the split dimension. begin[split_dim] = i * slice_size; limits[split_dim] = (i + 1) * slice_size; - ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides)); + ctx->SetOutput(i, xla::Slice(input, begin, limits, strides)); } } }; @@ -134,7 +135,7 @@ class SplitVOp : public XlaOpKernel { errors::InvalidArgument( "Number of ways to split should be > 0, but got ", num_split)); - // check that sizes are correct + // Check that sizes are correct. int total_split_size = 0; int neg_one_dim = -1; std::vector split_sizes_vec(num_split, -1); @@ -148,7 +149,7 @@ class SplitVOp : public XlaOpKernel { " number of elements as the output. Got ", split_size_shape.dims(), "-D and ", split_size_shape.num_elements(), " elements")); - // get the dimension of this split + // Get the dimension of this split. xla::Literal split_size_literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &split_size_literal)); @@ -199,7 +200,7 @@ class SplitVOp : public XlaOpKernel { // Slice out the ith split from the split dimension. limits[split_dim] = begin[split_dim] + slice_size; - ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides)); + ctx->SetOutput(i, xla::Slice(input, begin, limits, strides)); begin[split_dim] = limits[split_dim]; } } diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 0fb05a2be7b1034d6c2e864643b69647d622ede7..591e61b4c82836bc1995cd11c4c0314c9d854e50 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -144,24 +144,25 @@ class StackPushOp : public XlaOpKernel { // Initializes the Stack, if the element shape was not already known. OP_REQUIRES_OK(ctx, MaybeInitializeStack(b, resource, dtype_, elem_shape)); - xla::XlaOp ta = b->GetTupleElement(resource->value(), 0); - xla::XlaOp index = b->GetTupleElement(resource->value(), 1); + xla::XlaOp ta = xla::GetTupleElement(resource->value(), 0); + xla::XlaOp index = xla::GetTupleElement(resource->value(), 1); xla::XlaOp value = ctx->Input(1); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); - auto update = b->Reshape(value, slice_shape.dim_sizes()); + auto update = xla::Reshape(value, slice_shape.dim_sizes()); // TODO(phawkins): We don't check the index is in bounds --- there is no // error mechanism in XLA. - OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple( - {b->DynamicUpdateSlice(ta, update, start_indices), - b->Add(index, b->ConstantR0(1))}))); + OP_REQUIRES_OK(ctx, + resource->SetValue(xla::Tuple( + b, {xla::DynamicUpdateSlice(ta, update, start_indices), + xla::Add(index, xla::ConstantR0(b, 1))}))); ctx->SetOutput(0, value); } @@ -197,27 +198,27 @@ class StackPopOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, GetStackShape(b, resource, &stack_shape)); xla::XlaOp state = resource->value(); - xla::XlaOp ta = b->GetTupleElement(state, 0); - xla::XlaOp index = b->GetTupleElement(state, 1); + xla::XlaOp ta = xla::GetTupleElement(state, 0); + xla::XlaOp index = xla::GetTupleElement(state, 1); - index = b->Sub(index, b->ConstantR0(1)); - OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple({ta, index}))); + index = Sub(index, xla::ConstantR0(b, 1)); + OP_REQUIRES_OK(ctx, resource->SetValue(xla::Tuple(b, {ta, index}))); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), - xla::MakeEdgePaddingConfig({{0, stack_shape.dims() - 1}})); + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, stack_shape.dims() - 1}})); auto slice_shape = stack_shape.dim_sizes(); slice_shape[0] = 1LL; // TODO(phawkins): We don't check the index is in bounds --- there is no // error mechanism in XLA. - xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape); + xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); // Remove the leading '1' dimension. std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); - ctx->SetOutput(0, b->Reshape(read, value_shape)); + ctx->SetOutput(0, xla::Reshape(read, value_shape)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index a99d4ddc7c4956f7144512a9bdf6f4c2eb0f944f..50a455b5200d159b71969f45f318a30cb618b7db 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -15,11 +15,14 @@ limitations under the License. #include +#include "tensorflow/compiler/tf2xla/lib/random.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -32,17 +35,9 @@ namespace { // Rotates a 32-bit integer 'v' left by 'distance' bits. xla::XlaOp RotateLeftS32(xla::XlaBuilder* builder, const xla::XlaOp& v, int distance) { - return builder->Or( - builder->ShiftLeft(v, builder->ConstantR0(distance)), - builder->ShiftRightLogical(v, builder->ConstantR0(32 - distance))); -} - -// TODO(b/65209188): add a primitive XOR to XLA and call it here, rather than -// building XOR out of other bitwise operators. -xla::XlaOp BitwiseXor(xla::XlaBuilder* builder, const xla::XlaOp& x, - const xla::XlaOp& y) { - return builder->Or(builder->And(x, builder->Not(y)), - builder->And(builder->Not(x), y)); + return xla::Or( + xla::ShiftLeft(v, xla::ConstantR0(builder, distance)), + xla::ShiftRightLogical(v, xla::ConstantR0(builder, 32 - distance))); } using ThreeFry2x32State = std::array; @@ -58,22 +53,22 @@ ThreeFry2x32State ThreeFry2x32(xla::XlaBuilder* builder, std::array ks; // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm. - ks[2] = builder->ConstantR0(0x1BD11BDA); + ks[2] = xla::ConstantR0(builder, 0x1BD11BDA); for (int i = 0; i < 2; ++i) { ks[i] = key[i]; x[i] = input[i]; - ks[2] = BitwiseXor(builder, ks[2], key[i]); + ks[2] = xla::Xor(ks[2], key[i]); } - x[0] = builder->Add(x[0], ks[0]); - x[1] = builder->Add(x[1], ks[1]); + x[0] = xla::Add(x[0], ks[0]); + x[1] = xla::Add(x[1], ks[1]); // Performs a single round of the Threefry2x32 algorithm, with a rotation // amount 'rotation'. auto round = [builder](ThreeFry2x32State v, int rotation) { - v[0] = builder->Add(v[0], v[1]); + v[0] = xla::Add(v[0], v[1]); v[1] = RotateLeftS32(builder, v[1], rotation); - v[1] = BitwiseXor(builder, v[0], v[1]); + v[1] = xla::Xor(v[0], v[1]); return v; }; @@ -83,36 +78,36 @@ ThreeFry2x32State ThreeFry2x32(xla::XlaBuilder* builder, x = round(x, rotations[1]); x = round(x, rotations[2]); x = round(x, rotations[3]); - x[0] = builder->Add(x[0], ks[1]); - x[1] = builder->Add(builder->Add(x[1], ks[2]), builder->ConstantR0(1)); + x[0] = xla::Add(x[0], ks[1]); + x[1] = xla::Add(xla::Add(x[1], ks[2]), xla::ConstantR0(builder, 1)); x = round(x, rotations[4]); x = round(x, rotations[5]); x = round(x, rotations[6]); x = round(x, rotations[7]); - x[0] = builder->Add(x[0], ks[2]); - x[1] = builder->Add(builder->Add(x[1], ks[0]), builder->ConstantR0(2)); + x[0] = xla::Add(x[0], ks[2]); + x[1] = xla::Add(xla::Add(x[1], ks[0]), xla::ConstantR0(builder, 2)); x = round(x, rotations[0]); x = round(x, rotations[1]); x = round(x, rotations[2]); x = round(x, rotations[3]); - x[0] = builder->Add(x[0], ks[0]); - x[1] = builder->Add(builder->Add(x[1], ks[1]), builder->ConstantR0(3)); + x[0] = xla::Add(x[0], ks[0]); + x[1] = xla::Add(xla::Add(x[1], ks[1]), xla::ConstantR0(builder, 3)); x = round(x, rotations[4]); x = round(x, rotations[5]); x = round(x, rotations[6]); x = round(x, rotations[7]); - x[0] = builder->Add(x[0], ks[1]); - x[1] = builder->Add(builder->Add(x[1], ks[2]), builder->ConstantR0(4)); + x[0] = xla::Add(x[0], ks[1]); + x[1] = xla::Add(xla::Add(x[1], ks[2]), xla::ConstantR0(builder, 4)); x = round(x, rotations[0]); x = round(x, rotations[1]); x = round(x, rotations[2]); x = round(x, rotations[3]); - x[0] = builder->Add(x[0], ks[2]); - x[1] = builder->Add(builder->Add(x[1], ks[0]), builder->ConstantR0(5)); + x[0] = xla::Add(x[0], ks[2]); + x[1] = xla::Add(xla::Add(x[1], ks[0]), xla::ConstantR0(builder, 5)); return x; } @@ -123,8 +118,8 @@ xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed, const TensorShape& shape, double minval, double maxval) { // Split the seed into two 32-bit scalars to form a key. - auto seed0 = builder->Reshape(builder->Slice(seed, {0}, {1}, {1}), {}); - auto seed1 = builder->Reshape(builder->Slice(seed, {1}, {2}, {1}), {}); + auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); ThreeFry2x32State key = {seed0, seed1}; const int64 size = shape.num_elements(); @@ -133,81 +128,36 @@ xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed, // Fill the generator inputs with unique counter values. ThreeFry2x32State inputs; - TF_CHECK_OK(XlaHelpers::Iota(builder, DT_INT32, half_size, &inputs[0])); - inputs[1] = builder->Add(inputs[0], builder->ConstantR0(half_size)); + inputs[0] = xla::Iota(builder, xla::S32, half_size); + inputs[1] = xla::Add(inputs[0], xla::ConstantR0(builder, half_size)); ThreeFry2x32State outputs = ThreeFry2x32(builder, inputs, key); if (size_is_odd) { - outputs[1] = builder->Slice(outputs[1], {0}, {half_size - 1}, {1}); + outputs[1] = xla::Slice(outputs[1], {0}, {half_size - 1}, {1}); } auto bits = - builder->Reshape(builder->ConcatInDim(outputs, 0), shape.dim_sizes()); + xla::Reshape(xla::ConcatInDim(builder, outputs, 0), shape.dim_sizes()); // Form 22 random mantissa bits, with a leading 1 bit. The leading 1 bit // forces the random bits into the mantissa. constexpr int kFloatBits = 32; constexpr int kMantissaBits = 23; - bits = builder->Or( - builder->ShiftRightLogical( - bits, builder->ConstantR0(kFloatBits - kMantissaBits)), - builder->ConstantR0(bit_cast(1.0f))); - auto floats = builder->BitcastConvertType(bits, xla::F32); + bits = xla::Or( + xla::ShiftRightLogical( + bits, xla::ConstantR0(builder, kFloatBits - kMantissaBits)), + xla::ConstantR0(builder, bit_cast(1.0f))); + auto floats = xla::BitcastConvertType(bits, xla::F32); // We have a floating point number in the range [1.0, 2.0). // Subtract 1.0f to shift to the range [0.0, 1.0) - floats = builder->Sub(floats, builder->ConstantR0(1.0f)); + floats = xla::Sub(floats, xla::ConstantR0(builder, 1.0f)); // Multiply and add to shift to the range [minval, maxval). - floats = builder->Mul(floats, builder->ConstantR0(maxval - minval)); - floats = builder->Add(floats, builder->ConstantR0(minval)); + floats = xla::Mul(floats, xla::ConstantR0(builder, maxval - minval)); + floats = xla::Add(floats, xla::ConstantR0(builder, minval)); return floats; } -// Approximation for the inverse error function from -// Giles, M., "Approximating the erfinv function". -// The approximation has the form: -// w = -log((1 - x) * (1 + x)) -// if ( w < 5 ) { -// w = w - 2.5 -// p = sum_{i=1}^n lq[i]*w^i -// } else { -// w = sqrt(w) - 3 -// p = sum_{i=1}^n gq[i]*w^i -// } -// return p*x -xla::XlaOp ErfInvF32(xla::XlaBuilder* b, const xla::XlaOp& x, - const TensorShape& shape) { - constexpr int kDegree = 9; - constexpr std::array w_less_than_5_constants = { - 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, - -4.39150654e-06f, 0.00021858087f, -0.00125372503f, - -0.00417768164f, 0.246640727f, 1.50140941f}; - constexpr std::array w_greater_than_5_constants = { - -0.000200214257f, 0.000100950558f, 0.00134934322f, - -0.00367342844f, 0.00573950773f, -0.0076224613f, - 0.00943887047f, 1.00167406f, 2.83297682f}; - - auto one = b->ConstantR0(1.0); - auto w = b->Neg(b->Log(b->Mul(b->Sub(one, x), b->Add(one, x)))); - - auto lt = b->Lt(w, b->ConstantR0(5.0)); - auto coefficient = [&](int i) { - return b->Select( - lt, - b->Broadcast(b->ConstantR0(w_less_than_5_constants[i]), - shape.dim_sizes()), - b->Broadcast(b->ConstantR0(w_greater_than_5_constants[i]), - shape.dim_sizes())); - }; - w = b->Select(lt, b->Sub(w, b->ConstantR0(2.5f)), - b->Sub(b->SqrtF32(w), b->ConstantR0(3.0f))); - auto p = coefficient(0); - for (int i = 1; i < kDegree; ++i) { - p = b->Add(coefficient(i), b->Mul(p, w)); - } - return b->Mul(p, x); -} - } // namespace class StatelessRandomUniformOp : public XlaOpKernel { @@ -259,8 +209,8 @@ class StatelessRandomNormalOp : public XlaOpKernel { RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0); // Convert uniform distribution to normal distribution by computing // sqrt(2) * erfinv(x) - auto normal = builder->Mul(builder->ConstantR0(std::sqrt(2.0)), - ErfInvF32(builder, uniform, shape)); + auto normal = xla::Mul(xla::ConstantR0(builder, std::sqrt(2.0)), + ErfInv(uniform)); ctx->SetOutput(0, normal); } @@ -275,4 +225,37 @@ REGISTER_XLA_OP(Name("StatelessRandomNormal") .TypeConstraint("Tseed", DT_INT32), StatelessRandomNormalOp); +class StatelessTruncatedNormalOp : public XlaOpKernel { + public: + explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const DataType dtype = output_type(0); + + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + + TensorShape seed_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, seed_shape == TensorShape({2}), + errors::InvalidArgument("seed must have shape [2], not ", + seed_shape.DebugString())); + xla::XlaOp seed = ctx->Input(1); + xla::XlaBuilder* b = ctx->builder(); + + auto uniform = + RandomUniform(b, seed, shape, std::numeric_limits::min(), 1.0); + ctx->SetOutput(0, TruncatedNormal(dtype, uniform)); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp); +}; + +REGISTER_XLA_OP(Name("StatelessTruncatedNormal") + .CompileTimeConstInput("shape") + .TypeConstraint("dtype", DT_FLOAT) + .TypeConstraint("Tseed", DT_INT32), + StatelessTruncatedNormalOp); + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 55254c746e5ebaf6b468c24ab59b968bf0d6260b..c2165ccd86dfa1c119790beb20af0844fb1bbda8 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -92,12 +93,12 @@ class StridedSliceOp : public XlaOpKernel { xla::XlaOp slice = ctx->Input(0); if (!dimensions_to_reverse.empty()) { - slice = ctx->builder()->Rev(slice, dimensions_to_reverse); + slice = xla::Rev(slice, dimensions_to_reverse); } - slice = ctx->builder()->Slice(slice, slice_begin, slice_end, slice_strides); + slice = xla::Slice(slice, slice_begin, slice_end, slice_strides); - slice = ctx->builder()->Reshape(slice, final_shape.dim_sizes()); + slice = xla::Reshape(slice, final_shape.dim_sizes()); ctx->SetOutput(0, slice); } @@ -171,7 +172,7 @@ class StridedSliceGradOp : public XlaOpKernel { xla::XlaOp grad = ctx->Input(4); // Undo any new/shrink axes. - grad = ctx->builder()->Reshape(grad, processing_shape.dim_sizes()); + grad = xla::Reshape(grad, processing_shape.dim_sizes()); // Pad the input gradients. gtl::InlinedVector dimensions_to_reverse; @@ -204,9 +205,9 @@ class StridedSliceGradOp : public XlaOpKernel { } } if (!dimensions_to_reverse.empty()) { - grad = ctx->builder()->Rev(grad, dimensions_to_reverse); + grad = xla::Rev(grad, dimensions_to_reverse); } - grad = ctx->builder()->Pad(grad, zero, padding_config); + grad = xla::Pad(grad, zero, padding_config); ctx->SetOutput(0, grad); } @@ -306,17 +307,17 @@ class StridedSliceAssignOp : public XlaOpKernel { } if (!dimensions_to_reverse.empty()) { - rhs = ctx->builder()->Rev(rhs, dimensions_to_reverse); + rhs = xla::Rev(rhs, dimensions_to_reverse); } - rhs = ctx->builder()->Reshape(rhs, slice_dims); + rhs = xla::Reshape(rhs, slice_dims); if (lhs_shape.dims() == 0) { // TODO(b/38323843): DynamicUpdateSlice crashes on rank 0 inputs. Fix // and remove this workaround. lhs = rhs; } else { - lhs = ctx->builder()->DynamicUpdateSlice( - lhs, rhs, ctx->builder()->ConstantR1(slice_begin)); + lhs = xla::DynamicUpdateSlice( + lhs, rhs, xla::ConstantR1(ctx->builder(), slice_begin)); } OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs)); diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 9adee78a1fd1fb9a12afae83197425c328b5fe7e..2f650ce3052ee4502912891cd3f60cfaec8b1d7c 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -123,10 +124,9 @@ xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand, const xla::XlaOp& update, const gtl::ArraySlice& update_dims, const xla::XlaOp& start_indices) { - xla::XlaOp current = - builder->DynamicSlice(operand, start_indices, update_dims); - xla::XlaOp sum = builder->Add(current, update); - return builder->DynamicUpdateSlice(operand, sum, start_indices); + xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims); + xla::XlaOp sum = xla::Add(current, update); + return xla::DynamicUpdateSlice(operand, sum, start_indices); } class TensorArrayOp : public XlaOpKernel { @@ -162,7 +162,7 @@ class TensorArrayOp : public XlaOpKernel { ta_shape.AddDim(size); ta_shape.AppendShape(shape); xla::XlaOp zero = XlaHelpers::Zero(b, dtype_); - value = b->Broadcast(zero, ta_shape.dim_sizes()); + value = xla::Broadcast(zero, ta_shape.dim_sizes()); } XlaContext& xc = XlaContext::Get(ctx); @@ -215,12 +215,12 @@ class TensorArrayWriteOp : public XlaOpKernel { // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); - auto update = b->Reshape(value, slice_shape.dim_sizes()); + auto update = xla::Reshape(value, slice_shape.dim_sizes()); xla::XlaOp written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); @@ -259,17 +259,17 @@ class TensorArrayReadOp : public XlaOpKernel { // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), - xla::MakeEdgePaddingConfig({{0, ta_shape.dims() - 1}})); + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, ta_shape.dims() - 1}})); auto slice_shape = ta_shape.dim_sizes(); slice_shape[0] = 1LL; - xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape); + xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); // Remove the leading '1' dimension. std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); - ctx->SetOutput(0, b->Reshape(read, value_shape)); + ctx->SetOutput(0, xla::Reshape(read, value_shape)); } private: @@ -326,7 +326,7 @@ class TensorArrayGatherOp : public XlaOpKernel { for (auto i = 1; i < ta_shape.dims(); i++) { end[i] = ta_shape.dim_size(i); } - ctx->SetOutput(0, b->Slice(ta, begin, end, strides)); + ctx->SetOutput(0, xla::Slice(ta, begin, end, strides)); return; } } @@ -391,7 +391,7 @@ class TensorArrayScatterOp : public XlaOpKernel { } if (scatter_all_elements_in_order) { - ta = b->Add(ta, value); + ta = xla::Add(ta, value); } else { auto slice_dims = value_shape.dim_sizes(); slice_dims[0] = 1LL; @@ -407,13 +407,13 @@ class TensorArrayScatterOp : public XlaOpKernel { // Slice out part of the value. value_starts[0] = i; value_ends[0] = i + 1; - auto slice = b->Slice(value, value_starts, value_ends, value_strides); + auto slice = xla::Slice(value, value_starts, value_ends, value_strides); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto index = b->Slice(indices, {i}, {i + 1}, {1}); + auto index = xla::Slice(indices, {i}, {i + 1}, {1}); auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); } } @@ -452,7 +452,7 @@ class TensorArrayConcatOp : public XlaOpKernel { auto ta_dims = ta_shape.dim_sizes(); std::vector shape(ta_dims.begin() + 1, ta_dims.end()); shape[0] *= ta_shape.dim_size(0); - ctx->SetOutput(0, b->Reshape(ta, shape)); + ctx->SetOutput(0, xla::Reshape(ta, shape)); Tensor lengths(DT_INT64, {ta_dims[0]}); auto lengths_vec = lengths.vec(); @@ -522,8 +522,8 @@ class TensorArraySplitOp : public XlaOpKernel { value_shape.DebugString(), " vs. ", ta_shape.DebugString())); - OP_REQUIRES_OK(ctx, resource->SetValue(b->Add( - ta, b->Reshape(value, ta_shape.dim_sizes())))); + OP_REQUIRES_OK(ctx, resource->SetValue(xla::Add( + ta, xla::Reshape(value, ta_shape.dim_sizes())))); ctx->SetOutput(0, flow); } diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index e91075196bd8414939888e22b5483ad637487af6..c9e56942625a009fb3660f413a845547192460d5 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -93,9 +94,9 @@ class TileOp : public XlaOpKernel { if (one_dimension_is_broadcasted_without_multiple) { // Create a constant Zero the size of the output shape to leverage binary // operation broadcast semantics. - auto broadcasted_zero = ctx->builder()->Broadcast( + auto broadcasted_zero = xla::Broadcast( XlaHelpers::Zero(ctx->builder(), ctx->input_type(0)), output_shape); - ctx->SetOutput(0, ctx->builder()->Add(broadcasted_zero, input)); + ctx->SetOutput(0, xla::Add(broadcasted_zero, input)); return; } @@ -103,7 +104,7 @@ class TileOp : public XlaOpKernel { // dimension. This prepends the broadcasted dimensions, so an // input of shape [2,3,1] broadcast with multiples [5,4,3] will // end up with shape [5,4,3,2,3,1]. - auto broadcasted = ctx->builder()->Broadcast(input, multiples_array); + auto broadcasted = xla::Broadcast(input, multiples_array); // Now flatten and reshape. The broadcasted dimensions are // paired with the original dimensions so in the above example // we flatten [0,3,1,4,2,5] then reshape to [10,12,3]. @@ -112,8 +113,7 @@ class TileOp : public XlaOpKernel { flattened.push_back(i); flattened.push_back(i + output_shape.size()); } - xla::XlaOp output = - ctx->builder()->Reshape(broadcasted, flattened, output_shape); + xla::XlaOp output = xla::Reshape(broadcasted, flattened, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..8a1377fc38ff55146d9c8986e8163fb14c4e7294 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -0,0 +1,158 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/no_op.h" + +namespace tensorflow { +namespace { + +class TopKOp : public XlaOpKernel { + public: + explicit TopKOp(OpKernelConstruction* context) : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_)); + } + + void Compile(XlaOpKernelContext* context) override { + int64 k; + OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(1, &k)); + OP_REQUIRES(context, k >= 0, + errors::InvalidArgument("Need k >= 0, got ", k)); + const TensorShape input_shape = context->InputShape(0); + OP_REQUIRES(context, input_shape.dims() >= 1, + errors::InvalidArgument("input must be >= 1-D, got shape ", + input_shape.DebugString())); + OP_REQUIRES( + context, input_shape.dim_size(input_shape.dims() - 1) >= k, + errors::InvalidArgument("input must have at least k columns. Had ", + input_shape.dim_size(input_shape.dims() - 1), + ", needed ", k)); + + OP_REQUIRES( + context, input_shape.dims() == 1, + errors::Unimplemented("TopK is implemented for 1-D inputs, got shape ", + input_shape.DebugString())); + + const int64 n = input_shape.dim_size(0); + OP_REQUIRES(context, n < (1 << 16), + errors::Unimplemented( + "TopK is implemented for sizes up to 2**16, got shape ", + input_shape.DebugString())); + + xla::XlaBuilder* const b = context->builder(); + if (input_shape.dim_size(0) < k) { + k = input_shape.dim_size(0); + } + const xla::XlaOp input_bf16 = context->Input(0); + xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, n); + + // TODO(b/73891930): add a key-value sort to HLO, rather than using + // bit-packing tricks here. + + xla::XlaOp zero = xla::ConstantR0(b, 0); + + // max can either be 0x7FFFFFFF or 0x8000000. Neither choice is totally + // ideal. The implications of the choice are: + // + // 0x7FFFFFFF + // 1. +0.0 > -0.0 + // 2. The elements of the inputs and outputs are bitwise identical. + // 3. The sort is unstable since a later +0.0 will appear before an earlier + // -0.0. + // + // 0x8000000 + // 1. +0.0 == -0.0 + // 2. All -0.0 in the input are replaced with +0.0 in the output. + // 3. The sort is stable. + xla::XlaOp max = xla::ConstantR0(b, 0x80000000); + xla::XlaOp index_mask = xla::ConstantR0(b, 0x0000FFFF); + xla::XlaOp value_mask = xla::ConstantR0(b, 0xFFFF0000); + + // Convert to from bf16 to f32. The lower 16-bits are zero due to the + // definition of bf16. + xla::XlaOp input_f32 = xla::ConvertElementType(input_bf16, xla::F32); + + // Negate the input to reverse sort it. The lower 16-bits are zero, because + // negating a float is just inverting the high-bit. + xla::XlaOp negative_input_f32 = xla::Neg(input_f32); + + // Convert to a sign magnitude integer. The lower 16-bits are zero, since + // bitcast convert doesn't change any bits. + xla::XlaOp negative_input_sm32 = + xla::BitcastConvertType(negative_input_f32, xla::S32); + + // Convert from sign magnitude integer to two's complement integer. The + // lower 16-bits are zero on both sides of the select. On the false side, + // the value is unchanged, and on the true side, the lower 16-bits of max + // are all zero, so the lower 16-bits of the result of the subtraction will + // also be zero. + xla::XlaOp negative_input_s32 = + xla::Select(xla::Lt(negative_input_sm32, zero), + xla::Sub(max, negative_input_sm32), negative_input_sm32); + + // In order for the Or with iota_s32 to to work properly, the lower 16-bits + // of negative_input_32 must be zero. + + // Pack elements as: + // * upper 16 bits are the value + // * lower 16 bits are the index. + xla::XlaOp packed_s32 = xla::Or(negative_input_s32, iota_s32); + + // TODO(phawkins): use a more efficient algorithm that does not require a + // full sort. + xla::XlaOp sorted_s32 = xla::Slice(xla::Sort(packed_s32), + /*start_indices=*/{0}, + /*limit_indices=*/{k}, + /*strides=*/{1}); + + // Unpack the value/index. + xla::XlaOp indices_s32 = xla::And(sorted_s32, index_mask); + xla::XlaOp negative_values_s32 = xla::And(sorted_s32, value_mask); + + // Convert from two's complement integer to sign magnitude integer. + xla::XlaOp negative_values_sm32 = + xla::Select(xla::Lt(negative_values_s32, zero), + xla::Sub(max, negative_values_s32), negative_values_s32); + + xla::XlaOp negative_values_f32 = + xla::BitcastConvertType(negative_values_sm32, xla::F32); + + // Negate the values to get back the original inputs. + xla::XlaOp values_f32 = xla::Neg(negative_values_f32); + + // Convert from f32 to bf16. + xla::XlaOp values_bf16 = xla::ConvertElementType(values_f32, xla::BF16); + + context->SetOutput(0, values_bf16); + context->SetOutput(1, indices_s32); + } + + private: + bool sorted_; +}; + +REGISTER_XLA_OP( + Name("TopKV2").CompileTimeConstInput("k").TypeConstraint("T", DT_BFLOAT16), + TopKOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 34caefa050c0d58f5f7bad557286b6ed64b996ad..2e5d61e111c068a0e26dba62f29e7e268291dd1d 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -31,7 +31,6 @@ class ResourceApplyGradientDescent : public XlaOpKernel { : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { xla::XlaOp handle; - xla::XlaBuilder* b = ctx->builder(); DataType type = ctx->input_type(1); TensorShape var_shape; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle)); @@ -48,7 +47,7 @@ class ResourceApplyGradientDescent : public XlaOpKernel { var_shape.DebugString(), " vs ", delta_shape.DebugString())); - handle = b->Sub(handle, b->Mul(ctx->Input(1), ctx->Input(2))); + handle = xla::Sub(handle, xla::Mul(ctx->Input(1), ctx->Input(2))); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; @@ -63,8 +62,6 @@ class ResourceApplyMomentum : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* b = ctx->builder(); - DataType type = ctx->input_type(2); TensorShape var_shape, accum_shape; @@ -97,14 +94,14 @@ class ResourceApplyMomentum : public XlaOpKernel { xla::XlaOp grad = ctx->Input(3); xla::XlaOp momentum = ctx->Input(4); - accum = b->Add(b->Mul(accum, momentum), grad); + accum = xla::Add(xla::Mul(accum, momentum), grad); if (use_nesterov_) { // See https://github.com/tensorflow/tensorflow/pull/2798 for an // explanation of the reparameterization used here. - var = b->Sub( - var, b->Add(b->Mul(grad, lr), b->Mul(b->Mul(accum, momentum), lr))); + var = xla::Sub(var, xla::Add(xla::Mul(grad, lr), + xla::Mul(xla::Mul(accum, momentum), lr))); } else { - var = b->Sub(var, b->Mul(accum, lr)); + var = xla::Sub(var, xla::Mul(accum, lr)); } OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum)); @@ -149,10 +146,12 @@ class ResourceApplyAdagrad : public XlaOpKernel { xla::XlaOp lr = ctx->Input(2); xla::XlaOp grad = ctx->Input(3); - accum = b->Add(accum, b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0))); - var = b->Sub( - var, b->Mul(b->Mul(grad, lr), - b->Pow(accum, XlaHelpers::FloatLiteral(b, type, -0.5)))); + accum = + xla::Add(accum, xla::Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0))); + var = xla::Sub( + var, + xla::Mul(xla::Mul(grad, lr), + xla::Pow(accum, XlaHelpers::FloatLiteral(b, type, -0.5)))); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum)); } @@ -232,12 +231,13 @@ class ResourceApplyAdam : public XlaOpKernel { xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); xla::XlaOp alpha = - b->Div(b->Mul(lr, b->Pow(b->Sub(one, beta2_power), half)), - b->Sub(one, beta1_power)); - m = b->Add(m, b->Mul(b->Sub(grad, m), b->Sub(one, beta1))); - v = b->Add(v, b->Mul(b->Sub(b->Pow(grad, two), v), b->Sub(one, beta2))); - var = - b->Sub(var, b->Div(b->Mul(m, alpha), b->Add(b->Pow(v, half), epsilon))); + xla::Div(xla::Mul(lr, xla::Pow(xla::Sub(one, beta2_power), half)), + xla::Sub(one, beta1_power)); + m = xla::Add(m, xla::Mul(xla::Sub(grad, m), xla::Sub(one, beta1))); + v = xla::Add( + v, xla::Mul(xla::Sub(xla::Pow(grad, two), v), xla::Sub(one, beta2))); + var = xla::Sub(var, xla::Div(xla::Mul(m, alpha), + xla::Add(xla::Pow(v, half), epsilon))); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m)); @@ -320,16 +320,17 @@ class ResourceApplyRMSProp : public XlaOpKernel { // ms <- grad**2 (1 - rho) + ms * rho // // Which is the equation listed above. - xla::XlaOp new_ms = b->Add( - ms, - b->Mul(b->Sub(b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)), ms), - b->Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho))); + xla::XlaOp new_ms = xla::Add( + ms, xla::Mul( + xla::Sub(xla::Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)), + ms), + xla::Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho))); xla::XlaOp new_mom = - b->Add(b->Mul(mom, momentum), - b->Mul(b->Mul(grad, lr), - b->Pow(b->Add(new_ms, epsilon), - XlaHelpers::FloatLiteral(b, type, -0.5)))); - xla::XlaOp new_var = b->Sub(var, new_mom); + xla::Add(xla::Mul(mom, momentum), + xla::Mul(xla::Mul(grad, lr), + xla::Pow(xla::Add(new_ms, epsilon), + XlaHelpers::FloatLiteral(b, type, -0.5)))); + xla::XlaOp new_var = xla::Sub(var, new_mom); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, new_var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, new_ms)); @@ -424,21 +425,23 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0); xla::XlaOp grad_to_use; if (has_l2_shrinkage) { - grad_to_use = b->Add(grad, b->Mul(two, b->Mul(l2_shrinkage, var))); + grad_to_use = xla::Add(grad, xla::Mul(two, xla::Mul(l2_shrinkage, var))); } else { grad_to_use = grad; } - xla::XlaOp new_accum = b->Add(accum, b->Pow(grad_to_use, two)); - xla::XlaOp new_accum_lr_pow = b->Pow(new_accum, b->Neg(lr_power)); - xla::XlaOp accum_lr_pow = b->Pow(accum, b->Neg(lr_power)); - linear = b->Add( + xla::XlaOp new_accum = xla::Add(accum, xla::Pow(grad_to_use, two)); + xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, xla::Neg(lr_power)); + xla::XlaOp accum_lr_pow = xla::Pow(accum, xla::Neg(lr_power)); + linear = xla::Add( linear, - b->Sub(grad_to_use, - b->Mul(b->Div(b->Sub(new_accum_lr_pow, accum_lr_pow), lr), var))); - xla::XlaOp linear_clipped = b->Clamp(b->Neg(l1), linear, l1); - xla::XlaOp quadratic = b->Add(b->Div(new_accum_lr_pow, lr), b->Mul(two, l2)); - var = b->Div(b->Sub(linear_clipped, linear), quadratic); + xla::Sub(grad_to_use, + xla::Mul(xla::Div(xla::Sub(new_accum_lr_pow, accum_lr_pow), lr), + var))); + xla::XlaOp linear_clipped = xla::Clamp(xla::Neg(l1), linear, l1); + xla::XlaOp quadratic = + xla::Add(xla::Div(new_accum_lr_pow, lr), xla::Mul(two, l2)); + var = xla::Div(xla::Sub(linear_clipped, linear), quadratic); accum = new_accum; OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype, var)); diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index c167642174b328a968d7f7ce1f0ad6e0ab8a7a68..6c721c48fe3af45aff5cd0bd5e74e2693faf9f97 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -32,7 +33,8 @@ namespace { class TransposeOp : public XlaOpKernel { public: - explicit TransposeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit TransposeOp(OpKernelConstruction* ctx, bool conjugate = false) + : XlaOpKernel(ctx), conjugate_(conjugate) {} void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_shape = ctx->InputShape(0); @@ -78,19 +80,37 @@ class TransposeOp : public XlaOpKernel { errors::InvalidArgument(i, " is missing from 'perm' argument.")); } + xla::XlaOp transposed; // 0-D, 1-D, and identity transposes do nothing. if (dims <= 1 || is_identity) { - ctx->SetOutput(0, ctx->Input(0)); - return; + transposed = ctx->Input(0); + } else { + transposed = xla::Transpose(ctx->Input(0), transposed_order); } - ctx->SetOutput(0, - ctx->builder()->Transpose(ctx->Input(0), transposed_order)); + // Conjugate the transposed result if this is ConjugateTransposeOp. + if (conjugate_) { + ctx->SetOutput(0, xla::Conj(transposed)); + } else { + ctx->SetOutput(0, transposed); + } } + + private: + const bool conjugate_; +}; + +class ConjugateTransposeOp : public TransposeOp { + public: + explicit ConjugateTransposeOp(OpKernelConstruction* ctx) + : TransposeOp(ctx, /*conjugate=*/true) {} }; REGISTER_XLA_OP(Name("Transpose").CompileTimeConstInput("perm"), TransposeOp); +REGISTER_XLA_OP(Name("ConjugateTranspose").CompileTimeConstInput("perm"), + ConjugateTransposeOp); + // InvertPermutation frequently forms part of the gradient of Transpose. // // inv = InvertPermutationOp(T p) takes a permutation of @@ -127,7 +147,7 @@ class InvertPermutationOp : public XlaOpKernel { output[d] = i; } - ctx->SetOutput(0, ctx->builder()->ConstantR1(output)); + ctx->SetOutput(0, xla::ConstantR1(ctx->builder(), output)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 2521445e86998cb027f94838650a049c9fd7e1a3..e99691646125d45d6bc016331c984fde23f5df19 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -27,15 +27,13 @@ limitations under the License. namespace tensorflow { namespace { -// A subclass of a TlaUnaryOp must build the lambda computation that -// describes the scalar->scalar function to apply to each element of -// the input. #define XLAJIT_MAKE_UNARY(NAME, COMPUTATION) \ class NAME##Op : public XlaOpKernel { \ public: \ explicit NAME##Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} \ void Compile(XlaOpKernelContext* ctx) { \ xla::XlaBuilder* b = ctx->builder(); \ + (void)b; \ xla::XlaOp x = ctx->Input(0); \ xla::XlaOp y = COMPUTATION; \ ctx->SetOutput(0, y); \ @@ -43,84 +41,88 @@ namespace { }; \ REGISTER_XLA_OP(Name(#NAME), NAME##Op); -XLAJIT_MAKE_UNARY(ComplexAbs, b->Abs(x)); +XLAJIT_MAKE_UNARY(ComplexAbs, xla::Abs(x)); -XLAJIT_MAKE_UNARY(Angle, b->Atan2(b->Imag(x), b->Real(x))); +XLAJIT_MAKE_UNARY(Angle, xla::Atan2(xla::Imag(x), xla::Real(x))); -XLAJIT_MAKE_UNARY(Conj, b->Conj(x)); +XLAJIT_MAKE_UNARY(Conj, xla::Conj(x)); // Return x if x>0, otherwise -x. -XLAJIT_MAKE_UNARY(Abs, b->Abs(x)); +XLAJIT_MAKE_UNARY(Abs, xla::Abs(x)); // acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) XLAJIT_MAKE_UNARY( Acos, - b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0), - b->Atan2(b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)), - b->Mul(x, x)), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), - b->Add(XlaHelpers::One(b, input_type(0)), x)))); + xla::Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0), + xla::Atan2(xla::Pow(xla::Sub(XlaHelpers::One(b, input_type(0)), + xla::Mul(x, x)), + XlaHelpers::FloatLiteral(b, input_type(0), + 0.5)), + xla::Add(XlaHelpers::One(b, input_type(0)), x)))); // acosh(x) = log(x + sqrt(x^2 - 1)) // = log(x + sqrt((x+1)*(x-1))) XLAJIT_MAKE_UNARY( Acosh, - b->Log(b->Add(x, - b->Pow(b->Mul(b->Add(x, XlaHelpers::One(b, input_type(0))), - b->Sub(x, XlaHelpers::One(b, input_type(0)))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); + xla::Log(xla::Add( + x, xla::Pow(xla::Mul(xla::Add(x, XlaHelpers::One(b, input_type(0))), + xla::Sub(x, XlaHelpers::One(b, input_type(0)))), + XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) XLAJIT_MAKE_UNARY( Asin, - b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0), - b->Atan2(x, b->Add(XlaHelpers::One(b, input_type(0)), - b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)), - b->Mul(x, x)), + xla::Mul( + XlaHelpers::FloatLiteral(b, input_type(0), 2.0), + xla::Atan2(x, + xla::Add(XlaHelpers::One(b, input_type(0)), + xla::Pow(xla::Sub(XlaHelpers::One(b, input_type(0)), + xla::Mul(x, x)), XlaHelpers::FloatLiteral(b, input_type(0), 0.5)))))); // asinh(x) = log(x + sqrt(x^2 + 1)) XLAJIT_MAKE_UNARY( Asinh, - b->Log(b->Add(x, b->Pow(b->Add(b->Mul(x, x), - XlaHelpers::One(b, input_type(0))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); + xla::Log(xla::Add( + x, xla::Pow(xla::Add(xla::Mul(x, x), XlaHelpers::One(b, input_type(0))), + XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); -XLAJIT_MAKE_UNARY(Atan, b->Atan2(x, XlaHelpers::One(b, input_type(0)))); +XLAJIT_MAKE_UNARY(Atan, xla::Atan2(x, XlaHelpers::One(b, input_type(0)))); // atanh(x) = 0.5 * log((1 + x) / (1 - x)) XLAJIT_MAKE_UNARY( - Atanh, b->Mul(b->Log(b->Div(b->Add(XlaHelpers::One(b, input_type(0)), x), - b->Sub(XlaHelpers::One(b, input_type(0)), x))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); -XLAJIT_MAKE_UNARY(Ceil, b->Ceil(x)); -XLAJIT_MAKE_UNARY(Cos, b->Cos(x)); + Atanh, + xla::Mul(xla::Log(xla::Div(xla::Add(XlaHelpers::One(b, input_type(0)), x), + xla::Sub(XlaHelpers::One(b, input_type(0)), x))), + XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); +XLAJIT_MAKE_UNARY(Ceil, xla::Ceil(x)); +XLAJIT_MAKE_UNARY(Cos, xla::Cos(x)); XLAJIT_MAKE_UNARY(Cosh, - b->Mul(b->Add(b->Exp(x), b->Exp(b->Neg(x))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); -XLAJIT_MAKE_UNARY(Sin, b->Sin(x)); -XLAJIT_MAKE_UNARY(Exp, b->Exp(x)); - -XLAJIT_MAKE_UNARY(Expm1, b->Expm1(x)); - -XLAJIT_MAKE_UNARY(Floor, b->Floor(x)); -XLAJIT_MAKE_UNARY(IsFinite, b->IsFinite(x)); -XLAJIT_MAKE_UNARY(IsInf, b->Eq(b->Abs(x), - XlaHelpers::FloatLiteral( - b, input_type(0), - std::numeric_limits::infinity()))); -XLAJIT_MAKE_UNARY(IsNan, b->Ne(x, x)); + xla::Mul(xla::Add(xla::Exp(x), xla::Exp(xla::Neg(x))), + XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); +XLAJIT_MAKE_UNARY(Sin, xla::Sin(x)); +XLAJIT_MAKE_UNARY(Exp, xla::Exp(x)); + +XLAJIT_MAKE_UNARY(Expm1, xla::Expm1(x)); + +XLAJIT_MAKE_UNARY(Floor, xla::Floor(x)); +XLAJIT_MAKE_UNARY(IsFinite, xla::IsFinite(x)); +XLAJIT_MAKE_UNARY(IsInf, xla::Eq(xla::Abs(x), + XlaHelpers::FloatLiteral( + b, input_type(0), + std::numeric_limits::infinity()))); +XLAJIT_MAKE_UNARY(IsNan, xla::Ne(x, x)); // Return 1/x -XLAJIT_MAKE_UNARY(Inv, b->Div(XlaHelpers::One(b, input_type(0)), x)); -XLAJIT_MAKE_UNARY(Reciprocal, b->Div(XlaHelpers::One(b, input_type(0)), x)); -XLAJIT_MAKE_UNARY(Log, b->Log(x)); +XLAJIT_MAKE_UNARY(Inv, xla::Div(XlaHelpers::One(b, input_type(0)), x)); +XLAJIT_MAKE_UNARY(Reciprocal, xla::Div(XlaHelpers::One(b, input_type(0)), x)); +XLAJIT_MAKE_UNARY(Log, xla::Log(x)); -XLAJIT_MAKE_UNARY(Log1p, b->Log1p(x)); +XLAJIT_MAKE_UNARY(Log1p, xla::Log1p(x)); -XLAJIT_MAKE_UNARY(Invert, b->Not(x)); -XLAJIT_MAKE_UNARY(LogicalNot, b->Not(x)); -XLAJIT_MAKE_UNARY(Neg, b->Neg(x)); +XLAJIT_MAKE_UNARY(Invert, xla::Not(x)); +XLAJIT_MAKE_UNARY(LogicalNot, xla::Not(x)); +XLAJIT_MAKE_UNARY(Neg, xla::Neg(x)); // Implements Banker's rounding: numbers that are equidistant between two // integers are rounded towards even. @@ -130,35 +132,35 @@ static xla::XlaOp Round(xla::XlaBuilder* b, DataType dtype, auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0); auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0); - auto round_val = b->Floor(x); - auto fraction = b->Sub(x, round_val); + auto round_val = xla::Floor(x); + auto fraction = xla::Sub(x, round_val); auto nearest_even_int = - b->Sub(round_val, b->Mul(two, b->Floor(b->Mul(half, x)))); - auto is_odd = b->Eq(nearest_even_int, one); - return b->Select( - b->Or(b->Gt(fraction, half), b->And(b->Eq(fraction, half), is_odd)), - b->Add(round_val, one), round_val); + xla::Sub(round_val, xla::Mul(two, xla::Floor(xla::Mul(half, x)))); + auto is_odd = xla::Eq(nearest_even_int, one); + return xla::Select(xla::Or(xla::Gt(fraction, half), + xla::And(xla::Eq(fraction, half), is_odd)), + xla::Add(round_val, one), round_val); } XLAJIT_MAKE_UNARY(Rint, Round(b, input_type(0), x)); XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x)); -XLAJIT_MAKE_UNARY(Rsqrt, - b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5))); +XLAJIT_MAKE_UNARY(Rsqrt, xla::Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), + -0.5))); // Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2. static xla::XlaOp Sigmoid(xla::XlaBuilder* b, DataType dtype, const xla::XlaOp& x) { auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5); - return b->Add(half, b->Mul(half, b->Tanh(b->Mul(half, x)))); + return xla::Add(half, xla::Mul(half, xla::Tanh(xla::Mul(half, x)))); } XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(b, input_type(0), x)); // Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0. -XLAJIT_MAKE_UNARY(Sign, b->Sign(x)); +XLAJIT_MAKE_UNARY(Sign, xla::Sign(x)); XLAJIT_MAKE_UNARY(Sinh, - b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); + xla::Mul(xla::Sub(xla::Exp(x), xla::Exp(xla::Neg(x))), + XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); // softplus(x) = log(1 + exp(x)) // @@ -169,21 +171,21 @@ XLAJIT_MAKE_UNARY(Sinh, // This is equivalent to: // max(x, 0) + log1p(exp(-abs(x))) XLAJIT_MAKE_UNARY(Softplus, - b->Add(b->Max(x, XlaHelpers::Zero(b, input_type(0))), - b->Log1p(b->Exp(b->Neg(b->Abs(x)))))); + xla::Add(xla::Max(x, XlaHelpers::Zero(b, input_type(0))), + xla::Log1p(xla::Exp(xla::Neg(xla::Abs(x)))))); // softsign(x) = x / (abs(x) + 1) XLAJIT_MAKE_UNARY(Softsign, - b->Div(x, - b->Add(b->Abs(x), XlaHelpers::One(b, input_type(0))))); + xla::Div(x, xla::Add(xla::Abs(x), + XlaHelpers::One(b, input_type(0))))); XLAJIT_MAKE_UNARY(Sqrt, - b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); -XLAJIT_MAKE_UNARY(Square, b->Mul(x, x)); -XLAJIT_MAKE_UNARY(Tan, b->Div(b->Sin(x), b->Cos(x))); -XLAJIT_MAKE_UNARY(Tanh, b->Tanh(x)); + xla::Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); +XLAJIT_MAKE_UNARY(Square, xla::Mul(x, x)); +XLAJIT_MAKE_UNARY(Tan, xla::Div(xla::Sin(x), xla::Cos(x))); +XLAJIT_MAKE_UNARY(Tanh, xla::Tanh(x)); -XLAJIT_MAKE_UNARY(Real, b->Real(x)); -XLAJIT_MAKE_UNARY(Imag, b->Imag(x)); +XLAJIT_MAKE_UNARY(Real, xla::Real(x)); +XLAJIT_MAKE_UNARY(Imag, xla::Imag(x)); #undef XLAJIT_MAKE_UNARY @@ -197,14 +199,14 @@ class ErfOp : public XlaOpKernel { xla::PrimitiveType primitive_type; xla::XlaOp one = XlaHelpers::One(b, input_type(0)); xla::XlaOp x = ctx->Input(0); - xla::XlaOp abs_x = b->Abs(x); + xla::XlaOp abs_x = xla::Abs(x); OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(0), &primitive_type)); - auto y = b->Select(b->Gt(abs_x, one), - b->Sub(one, ComputeErfc(b, x, primitive_type)), - ComputeErf(b, x, primitive_type)); + auto y = + xla::Select(xla::Gt(abs_x, one), xla::Sub(one, Erfc(x, primitive_type)), + Erf(x, primitive_type)); ctx->SetOutput(0, y); } }; @@ -217,15 +219,15 @@ class ErfcOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); xla::XlaOp one = XlaHelpers::One(b, input_type(0)); xla::XlaOp x = ctx->Input(0); - xla::XlaOp abs_x = b->Abs(x); + xla::XlaOp abs_x = xla::Abs(x); xla::PrimitiveType primitive_type; OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(0), &primitive_type)); - auto y = b->Select(b->Lt(abs_x, one), - b->Sub(one, ComputeErf(b, x, primitive_type)), - ComputeErfc(b, x, primitive_type)); + auto y = + xla::Select(xla::Lt(abs_x, one), xla::Sub(one, Erf(x, primitive_type)), + Erfc(x, primitive_type)); ctx->SetOutput(0, y); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc index f87586ba578a6138e7fb921032e1a71f8c9ac80c..0e5d58ecbaeb13571f82a1311e29dc0ba91c11ac 100644 --- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -74,10 +75,9 @@ class UnpackOp : public XlaOpKernel { for (int i = 0; i < num; ++i) { start_indices[axis] = i; limit_indices[axis] = i + 1; - auto slice = ctx->builder()->Slice(input, start_indices, limit_indices, - strides); + auto slice = xla::Slice(input, start_indices, limit_indices, strides); // Reshape to drop the 'axis' dimension. - auto result = ctx->builder()->Reshape(slice, output_shape.dim_sizes()); + auto result = xla::Reshape(slice, output_shape.dim_sizes()); ctx->SetOutput(i, result); } } diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index a163fa0a5b34675e46d0d7c5f4e0ccb1e3fb18eb..febac8287350e32fccfd4cb5613f21b9a5fbcb95 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/kernels/shape_util.h" +#include "tensorflow/compiler/tf2xla/lib/scatter.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -23,8 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" -#include "tensorflow/core/kernels/no_op.h" namespace tensorflow { namespace { @@ -35,12 +33,33 @@ class VarIsInitializedOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { XlaResource* variable; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &variable)); - ctx->SetOutput(0, - ctx->builder()->ConstantR0(variable->initialized())); + ctx->SetOutput( + 0, xla::ConstantR0(ctx->builder(), variable->initialized())); } }; REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp); +class VariableShapeOp : public XlaOpKernel { + public: + explicit VariableShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + DataType variable_dtype; + TensorShape shape; + OP_REQUIRES_OK(ctx, + ctx->GetVariableTypeAndShape(0, &variable_dtype, &shape)); + Tensor shape_constant(out_dtype_, TensorShape({shape.dims()})); + OP_REQUIRES_OK(ctx, TensorShapeToConstant(shape, &shape_constant)); + ctx->SetConstantOutput(0, shape_constant); + } + + private: + DataType out_dtype_; +}; +REGISTER_XLA_OP(Name("VariableShape"), VariableShapeOp); + class ReadVariableOp : public XlaOpKernel { public: explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -77,7 +96,7 @@ class AssignAddVariableOp : public XlaOpKernel { xla::XlaOp handle; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); - handle = ctx->builder()->Add(handle, ctx->Input(1)); + handle = xla::Add(handle, ctx->Input(1)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; @@ -93,7 +112,7 @@ class AssignSubVariableOp : public XlaOpKernel { xla::XlaOp handle; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); - handle = ctx->builder()->Sub(handle, ctx->Input(1)); + handle = xla::Sub(handle, ctx->Input(1)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; @@ -125,29 +144,152 @@ class ResourceGatherOp : public XlaOpKernel { ctx->SetOutput(0, gather); } }; -REGISTER_XLA_OP(Name("ResourceGather").TypeConstraint("dtype", kNumericTypes), - ResourceGatherOp); +REGISTER_XLA_OP(Name("ResourceGather"), ResourceGatherOp); -class VariableShapeOp : public XlaOpKernel { +class ResourceScatterOp : public XlaOpKernel { public: - explicit VariableShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + explicit ResourceScatterOp( + OpKernelConstruction* context, bool indices_are_vectors, + std::function + combiner) + : XlaOpKernel(context), + indices_are_vectors_(indices_are_vectors), + combiner_(std::move(combiner)) {} + + void Compile(XlaOpKernelContext* context) override { + xla::XlaBuilder* builder = context->builder(); + + DataType dtype = context->input_type(2); + TensorShape var_shape; + xla::XlaOp var_value; + OP_REQUIRES_OK( + context, context->ReadVariableInput(0, dtype, &var_shape, &var_value)); + + const xla::XlaOp indices = context->Input(1); + const xla::XlaOp updates = context->Input(2); + + auto result = XlaScatter(var_value, updates, indices, indices_are_vectors_, + combiner_, builder); + OP_REQUIRES_OK(context, result.status()); + OP_REQUIRES_OK(context, + context->AssignVariable(0, dtype, result.ValueOrDie())); } - void Compile(XlaOpKernelContext* ctx) override { - DataType variable_dtype; - TensorShape shape; - OP_REQUIRES_OK(ctx, - ctx->GetVariableTypeAndShape(0, &variable_dtype, &shape)); - Tensor shape_constant(out_dtype_, TensorShape({shape.dims()})); - OP_REQUIRES_OK(ctx, TensorShapeToConstant(shape, &shape_constant)); - ctx->SetConstantOutput(0, shape_constant); + private: + const bool indices_are_vectors_; + const std::function + combiner_; +}; + +class ResourceScatterAddOp : public ResourceScatterOp { + public: + explicit ResourceScatterAddOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Add(x, y); } +}; +REGISTER_XLA_OP(Name("ResourceScatterAdd"), ResourceScatterAddOp); + +class ResourceScatterSubOp : public ResourceScatterOp { + public: + explicit ResourceScatterSubOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} private: - DataType out_dtype_; + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Sub(x, y); + } }; +REGISTER_XLA_OP(Name("ResourceScatterSub"), ResourceScatterSubOp); + +class ResourceScatterMulOp : public ResourceScatterOp { + public: + explicit ResourceScatterMulOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Mul(x, y); + } +}; +REGISTER_XLA_OP(Name("ResourceScatterMul"), ResourceScatterMulOp); + +class ResourceScatterDivOp : public ResourceScatterOp { + public: + explicit ResourceScatterDivOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Div(x, y); + } +}; +REGISTER_XLA_OP(Name("ResourceScatterDiv"), ResourceScatterDivOp); + +class ResourceScatterMinOp : public ResourceScatterOp { + public: + explicit ResourceScatterMinOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Min(x, y); + } +}; +REGISTER_XLA_OP(Name("ResourceScatterMin"), ResourceScatterMinOp); + +class ResourceScatterMaxOp : public ResourceScatterOp { + public: + explicit ResourceScatterMaxOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Max(x, y); + } +}; +REGISTER_XLA_OP(Name("ResourceScatterMax"), ResourceScatterMaxOp); + +class ResourceScatterUpdateOp : public ResourceScatterOp { + public: + explicit ResourceScatterUpdateOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/false, + /*combiner=*/{}) {} +}; +REGISTER_XLA_OP(Name("ResourceScatterUpdate"), ResourceScatterUpdateOp); + +class ResourceScatterNdUpdateOp : public ResourceScatterOp { + public: + explicit ResourceScatterNdUpdateOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/true, + /*combiner=*/{}) {} +}; +REGISTER_XLA_OP(Name("ResourceScatterNdUpdate"), ResourceScatterNdUpdateOp); + +class ResourceScatterNdAddOp : public ResourceScatterOp { + public: + explicit ResourceScatterNdAddOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/true, + /*combiner=*/Combine) {} + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Add(x, y); + } +}; +REGISTER_XLA_OP(Name("ResourceScatterNdAdd"), ResourceScatterNdAddOp); -REGISTER_XLA_OP(Name("VariableShape"), VariableShapeOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 5467c5d9946846ff9f14ce9c5aac9e2be4b9d6ab..340165bac6a2a214d8f84d5a116a4197b1df2c7b 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -246,7 +246,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { } } - xla::XlaOp init = builder->Tuple(inputs); + xla::XlaOp init = xla::Tuple(builder, inputs); VLOG(1) << "Building while loop"; @@ -255,22 +255,21 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { { std::unique_ptr cb = builder->CreateSubBuilder("cond_wrapper"); - auto inputs = cb->Parameter(0, cond_input_shape, "inputs"); - auto outputs = cb->Call(*cond.computation, {inputs}); - cb->GetTupleElement(outputs, 0); + auto inputs = xla::Parameter(cb.get(), 0, cond_input_shape, "inputs"); + auto outputs = xla::Call(cb.get(), *cond.computation, {inputs}); + xla::GetTupleElement(outputs, 0); xla::StatusOr result = cb->Build(); OP_REQUIRES_OK(ctx, result.status()); cond_wrapper = std::move(result.ValueOrDie()); } - xla::XlaOp while_result = - builder->While(cond_wrapper, *body.computation, init); + xla::XlaOp while_result = xla::While(cond_wrapper, *body.computation, init); // Sets non-variable outputs. for (int i = 0; i < ctx->num_outputs(); ++i) { if (ctx->input_type(i) != DT_RESOURCE) { ctx->SetOutput(body.input_mapping[i], - builder->GetTupleElement(while_result, i)); + xla::GetTupleElement(while_result, i)); } } @@ -284,7 +283,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, resource->SetFromPack( arguments[update.input_index].tensor_array_gradients, - builder->GetTupleElement(while_result, pos), builder)); + xla::GetTupleElement(while_result, pos), builder)); } VLOG(2) << "Loop-carried variable: pos: " << update.input_index << " name: " << resource->name() << " modified: " << update.modified diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index ee7f5d510ab7a3ce7d3bbe843c5fefd362f79b7b..04c600698c7d86808238f29cbeed6aa66acaee70 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -50,6 +50,20 @@ cc_library( ], ) +cc_library( + name = "random", + srcs = ["random.cc"], + hdrs = ["random.h"], + deps = [ + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/core:protos_all_cc", + ], +) + cc_library( name = "scatter", srcs = ["scatter.cc"], diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index ee0bb91a6b747ffc9e28e19dd4869a5b2cc43501..f9f3a8c8cfcbcd0a2ac853360c629d90c94db8b0 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -25,91 +26,94 @@ limitations under the License. namespace tensorflow { -xla::StatusOr BatchDot(xla::XlaBuilder* builder, xla::XlaOp x, - xla::XlaOp y, bool transpose_x, - bool transpose_y, bool conjugate_x, - bool conjugate_y) { - TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); - TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y)); - - // Check that both tensors have the same number of dimensions. There must be - // at least two (the batch dimensions can be empty). - if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) { - return errors::InvalidArgument( - "Arguments to BatchedDot have different ranks: ", - xla::ShapeUtil::HumanString(x_shape), " vs. ", - xla::ShapeUtil::HumanString(y_shape)); - } - const int ndims = xla::ShapeUtil::Rank(x_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to BatchedDot must have rank >= 2: ", ndims); - } - - // The batch dimensions must be equal and the matrix dimensions must be - // valid. - std::vector batch_dimension_numbers; - for (int i = 0; i < ndims - 2; ++i) { - if (x_shape.dimensions(i) != y_shape.dimensions(i)) { +xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, + bool transpose_y, bool conjugate_x, bool conjugate_y) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y)); + + // Check that both tensors have the same number of dimensions. There must be + // at least two (the batch dimensions can be empty). + if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) { return errors::InvalidArgument( - "Dimension ", i, " of inputs to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(x_shape), " vs ", + "Arguments to BatchedDot have different ranks: ", + xla::ShapeUtil::HumanString(x_shape), " vs. ", xla::ShapeUtil::HumanString(y_shape)); } - batch_dimension_numbers.push_back(i); - } - - int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); - int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2); - if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { - return errors::InvalidArgument( - "Dimensions ", x_inner_dim, " and ", y_inner_dim, - " of arguments to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x, - " vs. ", xla::ShapeUtil::HumanString(y_shape), - " transpose: ", transpose_y); - } - - // Check for zero lhs/rhs dim size. - if (xla::ShapeUtil::IsZeroElementArray(x_shape) || - xla::ShapeUtil::IsZeroElementArray(y_shape)) { - std::vector dimensions(batch_dimension_numbers.size()); - for (int i = 0; i < batch_dimension_numbers.size(); ++i) { - dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); + const int ndims = xla::ShapeUtil::Rank(x_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to BatchedDot must have rank >= 2: ", ndims); + } + + // The batch dimensions must be equal and the matrix dimensions must be + // valid. + std::vector batch_dimension_numbers; + for (int i = 0; i < ndims - 2; ++i) { + if (x_shape.dimensions(i) != y_shape.dimensions(i)) { + return errors::InvalidArgument( + "Dimension ", i, " of inputs to BatchedDot must be equal: ", + xla::ShapeUtil::HumanString(x_shape), " vs ", + xla::ShapeUtil::HumanString(y_shape)); + } + batch_dimension_numbers.push_back(i); + } + + int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); + int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2); + if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { + return errors::InvalidArgument( + "Dimensions ", x_inner_dim, " and ", y_inner_dim, + " of arguments to BatchedDot must be equal: ", + xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x, + " vs. ", xla::ShapeUtil::HumanString(y_shape), + " transpose: ", transpose_y); + } + + // Check for zero lhs/rhs dim size. + if (xla::ShapeUtil::IsZeroElementArray(x_shape) || + xla::ShapeUtil::IsZeroElementArray(y_shape)) { + std::vector dimensions(batch_dimension_numbers.size()); + for (int i = 0; i < batch_dimension_numbers.size(); ++i) { + dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); + } + int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); + int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); + dimensions.push_back(x_shape.dimensions(x_outer_dim)); + dimensions.push_back(y_shape.dimensions(y_outer_dim)); + return xla::Broadcast( + xla::ConstantLiteral(builder, + xla::Literal::Zero(x_shape.element_type())), + dimensions); + } + + if (x_shape.element_type() == xla::C64 && conjugate_x) { + x = xla::Conj(x); + } + if (y_shape.element_type() == xla::C64 && conjugate_y) { + y = xla::Conj(y); + } + + // If there are no batch dimensions, use a regular Dot. + // TODO(b/69062148) Remove this code when Dot emitters can be passed + // dimensions to transpose directly (i.e. without requiring a Transpose + // HLO). + if (batch_dimension_numbers.empty()) { + auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x; + auto rhs = transpose_y ? xla::Transpose(y, {1, 0}) : y; + return xla::Dot(lhs, rhs); + } + + xla::DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); + dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); + for (auto batch_dimension_number : batch_dimension_numbers) { + dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); + dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); } - int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); - int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); - dimensions.push_back(x_shape.dimensions(x_outer_dim)); - dimensions.push_back(y_shape.dimensions(y_outer_dim)); - return builder->Broadcast( - builder->ConstantLiteral(xla::Literal::Zero(x_shape.element_type())), - dimensions); - } - - if (x_shape.element_type() == xla::C64 && conjugate_x) { - x = builder->Conj(x); - } - if (y_shape.element_type() == xla::C64 && conjugate_y) { - y = builder->Conj(y); - } - - // If there are no batch dimensions, use a regular Dot. - // TODO(b/69062148) Remove this code when Dot emitters can be passed - // dimensions to transpose directly (i.e. without requiring a Transpose HLO). - if (batch_dimension_numbers.empty()) { - auto lhs = transpose_x ? builder->Transpose(x, {1, 0}) : x; - auto rhs = transpose_y ? builder->Transpose(y, {1, 0}) : y; - return builder->Dot(lhs, rhs); - } - - xla::DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); - dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); - for (auto batch_dimension_number : batch_dimension_numbers) { - dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); - dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); - } - return builder->DotGeneral(x, y, dot_dnums); + return xla::DotGeneral(x, y, dot_dnums); + }); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h index 1acc72033b05e73b0f5f88907df20cde5cfffbf0..d07a9486f18c0b8f26782123a8fba4ba228f71ee 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -43,10 +43,9 @@ namespace tensorflow { // It is computed as: // // output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -xla::StatusOr BatchDot(xla::XlaBuilder* builder, xla::XlaOp x, - xla::XlaOp y, bool transpose_x, - bool transpose_y, bool conjugate_x = false, - bool conjugate_y = false); +xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, + bool transpose_y = false, bool conjugate_x = false, + bool conjugate_y = false); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index 20925118bf598a6436c43bd727ce40e3abafc46c..a90178c7d9b8f7ab0f80b66962e6dfeed8be9631 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -47,178 +48,163 @@ namespace { // l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) / // l[..., j, j] // return l -xla::StatusOr CholeskyUnblocked(xla::XlaBuilder* builder, - const xla::XlaOp& a) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int n_dims = xla::ShapeUtil::Rank(a_shape); - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - gtl::ArraySlice major_dims(xla::AsInt64Slice(a_shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - 2); - - xla::XlaOp l = Zeros(builder, a_shape); - - // Construct the for loop body to iterate over rows. - auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, - xla::XlaBuilder* body_builder) - -> xla::StatusOr> { - xla::Shape col_shape; - xla::Shape row_shape; - for (int64 d : major_dims) { - row_shape.add_dimensions(d); - col_shape.add_dimensions(d); - } - row_shape.add_dimensions(1); - row_shape.add_dimensions(n); - row_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_row = Zeros(body_builder, row_shape); - - col_shape.add_dimensions(n); - col_shape.add_dimensions(1); - col_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_col = Zeros(body_builder, col_shape); - - std::vector mask_vector(n); - std::iota(mask_vector.begin(), mask_vector.end(), 0); - auto mask_range = body_builder->ConstantR1(mask_vector); - auto mask_range_row = body_builder->Broadcast( - body_builder->Reshape(mask_range, {0}, {1, n}), major_dims); - auto mask_range_col = body_builder->Broadcast( - body_builder->Reshape(mask_range, {0}, {n, 1}), major_dims); - auto body_a = loop_vars[0]; - auto body_l = loop_vars[1]; - - // row = l[..., i, :i] - // select the whole i-th row, then mask out all columns past i-1 - auto zero = body_builder->ConstantR0(0); - TF_ASSIGN_OR_RETURN(auto l_i, DynamicSliceInMinorDims(body_builder, body_l, - {i, zero}, {1, n})); - auto row = body_builder->Select(body_builder->Ge(mask_range_row, i), - mask_zeros_row, l_i); - // a[..., i, i] - TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(body_builder, body_a, - {i, i}, {1, 1})); - // np.dot(row, np.swapaxes(row, -1, -2)) - xla::XlaOp diag_dot; - TF_ASSIGN_OR_RETURN(diag_dot, BatchDot(body_builder, row, row, - /*transpose_x=*/false, - /*transpose_y=*/true)); - // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, - // np.swapaxes(row, -1, -2))) - auto l_ii = body_builder->Pow( - body_builder->Sub(a_ii, diag_dot), - FloatLiteral(body_builder, a_shape.element_type(), 0.5)); - - // a[..., i+1:, i] - // select the whole i-th column, then mask out all rows above i+1 +xla::XlaOp CholeskyUnblocked(xla::XlaOp a) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + const int n_dims = xla::ShapeUtil::Rank(a_shape); + const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + gtl::ArraySlice major_dims(xla::AsInt64Slice(a_shape.dimensions()), + /*pos=*/0, + /*len=*/n_dims - 2); + + xla::XlaOp l = Zeros(builder, a_shape); + + // Construct the for loop body to iterate over rows. + auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + xla::XlaBuilder* body_builder) + -> xla::StatusOr> { + xla::Shape col_shape; + xla::Shape row_shape; + for (int64 d : major_dims) { + row_shape.add_dimensions(d); + col_shape.add_dimensions(d); + } + row_shape.add_dimensions(1); + row_shape.add_dimensions(n); + row_shape.set_element_type(a_shape.element_type()); + auto mask_zeros_row = Zeros(body_builder, row_shape); + + col_shape.add_dimensions(n); + col_shape.add_dimensions(1); + col_shape.set_element_type(a_shape.element_type()); + auto mask_zeros_col = Zeros(body_builder, col_shape); + + std::vector mask_vector(n); + std::iota(mask_vector.begin(), mask_vector.end(), 0); + auto mask_range = xla::ConstantR1(body_builder, mask_vector); + auto mask_range_row = + xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims); + auto mask_range_col = + xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims); + auto body_a = loop_vars[0]; + auto body_l = loop_vars[1]; + + // row = l[..., i, :i] + // select the whole i-th row, then mask out all columns past i-1 + auto zero = xla::ConstantR0(body_builder, 0); + auto l_i = DynamicSliceInMinorDims(body_l, {i, zero}, {1, n}); + auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i); + // a[..., i, i] + auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); + // np.dot(row, np.swapaxes(row, -1, -2)) + auto diag_dot = BatchDot(row, row, + /*transpose_x=*/false, + /*transpose_y=*/true); + // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, + // np.swapaxes(row, -1, -2))) + auto l_ii = + xla::Pow(a_ii - diag_dot, + FloatLiteral(body_builder, a_shape.element_type(), 0.5)); + + // a[..., i+1:, i] + // select the whole i-th column, then mask out all rows above i+1 + auto a_0i = DynamicSliceInMinorDims(body_a, {i}, {1}); + auto a_ip1i = + xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i); + + // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) / + // l[..., i, i] + // The columns in [i, n] are zeroed out in `row`, so we just have to + // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i], + // r.T) + auto dot = BatchDot(body_l, row, + /*transpose_x=*/false, + /*transpose_y=*/true); + // np.dot(l[..., i+1:, :i], r.T) + auto dot_ip1 = + xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot); + + body_l = + DynamicUpdateSliceInMinorDims(body_l, (a_ip1i - dot_ip1) / l_ii, {i}); + // Assign the diagonal after the rest of the column because otherwise the + // column assign will wrap around and overwrite the diagonal assign. + body_l = DynamicUpdateSliceInMinorDims(body_l, l_ii, {i, i}); + + return std::vector{body_a, body_l}; + }; + TF_ASSIGN_OR_RETURN( - auto a_0i, DynamicSliceInMinorDims(body_builder, body_a, {i}, {1})); - auto a_ip1i = body_builder->Select(body_builder->Le(mask_range_col, i), - mask_zeros_col, a_0i); - - // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) / - // l[..., i, i] - // The columns in [i, n] are zeroed out in `row`, so we just have to - // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i], - // r.T) - TF_ASSIGN_OR_RETURN(auto dot, BatchDot(body_builder, body_l, row, - /*transpose_x=*/false, - /*transpose_y=*/true)); - // np.dot(l[..., i+1:, :i], r.T) - auto dot_ip1 = body_builder->Select(body_builder->Le(mask_range_col, i), - mask_zeros_col, dot); - - auto col_update = - body_builder->Div(body_builder->Sub(a_ip1i, dot_ip1), l_ii); - TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims( - body_builder, body_l, col_update, {i})); - // Assign the diagonal after the rest of the column because otherwise the - // column assign will wrap around and overwrite the diagonal assign. - TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims( - body_builder, body_l, l_ii, {i, i})); - - return std::vector{body_a, body_l}; - }; - - TF_ASSIGN_OR_RETURN( - auto cholesky_while, - XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder)); - - return cholesky_while[1]; + auto cholesky_while, + XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder)); + + return cholesky_while[1]; + }); } } // namespace -xla::StatusOr Cholesky(xla::XlaBuilder* builder, xla::XlaOp a, - int64 block_size) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int ndims = xla::ShapeUtil::Rank(a_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to Cholesky must have rank >= 2: ", ndims); - } - - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) { - return errors::InvalidArgument( - "Arguments to Cholesky must be square matrices: ", - xla::ShapeUtil::HumanString(a_shape)); - } - - if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to Cholesky must be >= 1; got ", block_size); - } - - // Blocked left-looking Cholesky factorization. - // Algorithm 1 from - // Haidar, Azzam, et al. "High-performance Cholesky factorization for GPU-only - // execution." Proceedings of General Purpose GPUs. ACM, 2017. - xla::XlaOp l = Zeros(builder, a_shape); - for (int64 i = 0; i < n; i += block_size) { - int64 k = std::min(block_size, n - i); - if (i > 0) { - // TODO(phawkins): consider implementing SYRK for the diagonal part of - // the panel. - // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i])) - TF_ASSIGN_OR_RETURN(auto lhs, - SliceInMinorDims(builder, l, {i, 0}, {n, i})); - TF_ASSIGN_OR_RETURN(auto rhs, - SliceInMinorDims(builder, l, {i, 0}, {i + k, i})); - TF_ASSIGN_OR_RETURN(auto delta, - BatchDot(builder, lhs, rhs, /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN(auto before, - SliceInMinorDims(builder, a, {i, i}, {n, i + k})); - TF_ASSIGN_OR_RETURN( - a, UpdateSliceInMinorDims(builder, a, builder->Sub(before, delta), - {i, i})); +xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + const int ndims = xla::ShapeUtil::Rank(a_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to Cholesky must have rank >= 2: ", ndims); + } + + const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) { + return errors::InvalidArgument( + "Arguments to Cholesky must be square matrices: ", + xla::ShapeUtil::HumanString(a_shape)); + } + + if (block_size < 1) { + return errors::InvalidArgument( + "block_size argument to Cholesky must be >= 1; got ", block_size); } - // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k]) - TF_ASSIGN_OR_RETURN(auto x, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto factorized, CholeskyUnblocked(builder, x)); - TF_ASSIGN_OR_RETURN(l, - UpdateSliceInMinorDims(builder, l, factorized, {i, i})); - - if (i + k < n) { - // l[i+k:, i:i+k] = trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k]) - TF_ASSIGN_OR_RETURN(auto panel, - SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); - TF_ASSIGN_OR_RETURN(auto update, - TriangularSolve(builder, factorized, panel, - /*left_side=*/false, - /*lower=*/true, - /*transpose_a=*/true, - /*conjugate_a=*/false, - /*block_size=*/block_size)); - TF_ASSIGN_OR_RETURN( - l, UpdateSliceInMinorDims(builder, l, update, {i + k, i})); + // Blocked left-looking Cholesky factorization. + // Algorithm 1 from + // Haidar, Azzam, et al. "High-performance Cholesky factorization for + // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017. + xla::XlaOp l = Zeros(builder, a_shape); + for (int64 i = 0; i < n; i += block_size) { + int64 k = std::min(block_size, n - i); + if (i > 0) { + // TODO(phawkins): consider implementing SYRK for the diagonal part of + // the panel. + // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i])) + auto lhs = SliceInMinorDims(l, {i, 0}, {n, i}); + auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i}); + auto delta = BatchDot(lhs, rhs, /*transpose_x=*/false, + /*transpose_y=*/true); + auto before = SliceInMinorDims(a, {i, i}, {n, i + k}); + a = UpdateSliceInMinorDims(a, before - delta, {i, i}); + } + + // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k]) + auto x = SliceInMinorDims(a, {i, i}, {i + k, i + k}); + auto factorized = CholeskyUnblocked(x); + l = UpdateSliceInMinorDims(l, factorized, {i, i}); + + if (i + k < n) { + // l[i+k:, i:i+k] = + // trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k]) + auto panel = SliceInMinorDims(a, {i + k, i}, {n, i + k}); + auto update = TriangularSolve(factorized, panel, + /*left_side=*/false, + /*lower=*/true, + /*transpose_a=*/true, + /*conjugate_a=*/false, + /*block_size=*/block_size); + l = UpdateSliceInMinorDims(l, update, {i + k, i}); + } } - } - return l; + return l; + }); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h index 20fca7969ece2729a44933fd3ef3f87230ab6cad..0f6e0e9d152ec5daedeb9c0e355bfb9731759094 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -30,8 +30,7 @@ namespace tensorflow { // TODO(phawkins): check for negative values on the diagonal and return an // error, instead of silently yielding NaNs. // TODO(znado): handle the complex Hermitian case -xla::StatusOr Cholesky(xla::XlaBuilder* builder, xla::XlaOp a, - int64 block_size = 256); +xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/random.cc b/tensorflow/compiler/tf2xla/lib/random.cc new file mode 100644 index 0000000000000000000000000000000000000000..3dfa66029ca84fad9c511e7b32a906ee41d37812 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/random.cc @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/random.h" + +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace tensorflow { + +xla::XlaOp TruncatedNormal(const DataType dtype, xla::XlaOp uniform) { + xla::XlaBuilder* builder = uniform.builder(); + auto normal_cdf = [](double x) { + return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0; + }; + + const double kA = -2.0; + const double kB = 2.0; + const double kMu = 0.0; + const double kSigma = 1.0; + const double kAlpha = (kA - kMu) / kSigma; + const double kBeta = (kB - kMu) / kSigma; + const double kAlphaNormalCdf = normal_cdf(kAlpha); + const double kBetaNormalCdf = normal_cdf(kBeta); + const double kZ = kBetaNormalCdf - kAlphaNormalCdf; + + xla::XlaOp one = XlaHelpers::FloatLiteral(builder, dtype, 1.0); + xla::XlaOp two = XlaHelpers::FloatLiteral(builder, dtype, 2.0); + xla::XlaOp sqrt_2 = XlaHelpers::FloatLiteral(builder, dtype, std::sqrt(2.0)); + + xla::XlaOp z = XlaHelpers::FloatLiteral(builder, dtype, kZ); + xla::XlaOp alpha_normal_cdf = + XlaHelpers::FloatLiteral(builder, dtype, kAlphaNormalCdf); + + // probit(p) = sqrt(2) * erfinv(2*p-1) + auto p = xla::Add(alpha_normal_cdf, xla::Mul(z, uniform)); + auto erfinv_input = xla::Sub(xla::Mul(p, two), one); + return xla::Mul(sqrt_2, ErfInv(erfinv_input)); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/random.h b/tensorflow/compiler/tf2xla/lib/random.h new file mode 100644 index 0000000000000000000000000000000000000000..39cbcf9c5eccffa0035ff4c5e3d9afdb129f05cc --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/random.h @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_ + +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace tensorflow { + +// Builds an array filled with values sampled from a truncated normal +// distribution such that no values are greater than two or less than negative +// two. +// +// The "uniform" parameter must be an array of random numbers distributed in +// (0,1). +xla::XlaOp TruncatedNormal(DataType dtype, xla::XlaOp uniform); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_ diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index d5a27abb2585f699ae2719cb8a6b9a829263389e..85e3d3ab85a89615cc5a01bdb4ec8f7fec30d58e 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -97,8 +98,8 @@ xla::StatusOr XlaScatter( buffer_shape_post_axes.end()); // Construct the initial values of the loop-carried Tensors. - auto flat_indices = builder->Reshape(indices, flat_indices_shape); - auto flat_updates = builder->Reshape(updates, flat_updates_shape); + auto flat_indices = xla::Reshape(indices, flat_indices_shape); + auto flat_updates = xla::Reshape(updates, flat_updates_shape); auto init = {flat_indices, flat_updates, buffer}; // Constructs the loop body. The implementation of scatter is essentially: @@ -112,46 +113,44 @@ xla::StatusOr XlaScatter( auto updates = loop_vars[1]; auto buffer = loop_vars[2]; - auto zero_index = body_builder->ConstantLiteral( - xla::Literal::Zero(indices_shape.element_type())); + auto zero_index = xla::ConstantLiteral( + body_builder, xla::Literal::Zero(indices_shape.element_type())); // Slice the i-th index from the indices array. xla::XlaOp index; - auto indices_offset = body_builder->Reshape(i, {1}); + auto indices_offset = xla::Reshape(i, {1}); if (indices_are_vectors) { - indices_offset = body_builder->Pad(indices_offset, zero_index, - xla::MakeEdgePaddingConfig({{0, 1}})); + indices_offset = xla::Pad(indices_offset, zero_index, + xla::MakeEdgePaddingConfig({{0, 1}})); - index = body_builder->DynamicSlice(indices, indices_offset, - {1, num_index_dims}); - index = body_builder->Collapse(index, {0, 1}); + index = xla::DynamicSlice(indices, indices_offset, {1, num_index_dims}); + index = xla::Collapse(index, {0, 1}); } else { - index = body_builder->DynamicSlice(indices, indices_offset, {1}); + index = xla::DynamicSlice(indices, indices_offset, {1}); } // Discard updates with negative indices, since some users expect this. - auto index_in_range = - body_builder->ReduceAll(body_builder->Le(zero_index, index), - body_builder->ConstantR0(true), - xla::CreateScalarAndComputation(body_builder)); + auto index_in_range = xla::ReduceAll( + xla::Le(zero_index, index), xla::ConstantR0(body_builder, true), + xla::CreateScalarAndComputation(body_builder)); // Make the index in bounds to prevent implementation defined behavior. - index = body_builder->Max(index, zero_index); - index = body_builder->Pad( + index = xla::Max(index, zero_index); + index = xla::Pad( index, zero_index, xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); // Slice the i-th index from the updates array. - auto updates_offset = body_builder->Reshape(i, {1}); - updates_offset = body_builder->Pad( + auto updates_offset = xla::Reshape(i, {1}); + updates_offset = xla::Pad( updates_offset, zero_index, xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); std::vector flat_updates_slice_shape({1}); flat_updates_slice_shape.insert(flat_updates_slice_shape.end(), buffer_shape_post_axes.begin(), buffer_shape_post_axes.end()); - auto update = body_builder->DynamicSlice(updates, updates_offset, - flat_updates_slice_shape); + auto update = + xla::DynamicSlice(updates, updates_offset, flat_updates_slice_shape); // Unflatten the major (iteration) dimensions of the slice to their // original shape. @@ -159,20 +158,19 @@ xla::StatusOr XlaScatter( updates_slice_shape.insert(updates_slice_shape.end(), buffer_shape_post_axes.begin(), buffer_shape_post_axes.end()); - update = body_builder->Reshape(update, updates_slice_shape); + update = xla::Reshape(update, updates_slice_shape); // Apply the update to the buffer. If there is a combiner, use it to merge // the current values with the update. - auto current_value = - body_builder->DynamicSlice(buffer, index, updates_slice_shape); + auto current_value = xla::DynamicSlice(buffer, index, updates_slice_shape); if (combiner) { update = combiner(current_value, update, body_builder); } // Use the current value instead of the update if the index is out of // bounds. - update = body_builder->Select(index_in_range, update, current_value); + update = xla::Select(index_in_range, update, current_value); // Apply the update. - buffer = body_builder->DynamicUpdateSlice(buffer, update, index); + buffer = xla::DynamicUpdateSlice(buffer, update, index); return std::vector{indices, updates, buffer}; }; diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index b4503601f94baa5a595a64c9fc81bc92d9980ac6..0d3ce129c7091e983d0826aac2a6015f4e6f7cf4 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -29,619 +30,564 @@ limitations under the License. namespace tensorflow { -xla::StatusOr TriangularSolve(xla::XlaBuilder* builder, - const xla::XlaOp& a, xla::XlaOp b, - bool left_side, bool lower, - bool transpose_a, bool conjugate_a, - int64 block_size) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) { - return errors::InvalidArgument( - "Arguments to TriangularSolve have different ranks: ", - xla::ShapeUtil::HumanString(a_shape), " vs. ", - xla::ShapeUtil::HumanString(b_shape)); - } - const int ndims = xla::ShapeUtil::Rank(a_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to TriangularSolve must have rank >= 2: ", ndims); - } - // The batch dimensions must be equal. - std::vector batch_dimensions; - for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape.dimensions(i); - int64 b_size = b_shape.dimensions(i); - if (a_size != b_size) { +xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, + bool lower, bool transpose_a, bool conjugate_a, + int64 block_size) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) { return errors::InvalidArgument( - "Batch dimensions of arguments to TriangularSolve must be equal: ", - xla::ShapeUtil::HumanString(a_shape), " vs ", + "Arguments to TriangularSolve have different ranks: ", + xla::ShapeUtil::HumanString(a_shape), " vs. ", xla::ShapeUtil::HumanString(b_shape)); } - batch_dimensions.push_back(a_size); - } - - if (xla::ShapeUtil::GetDimension(a_shape, -1) != - xla::ShapeUtil::GetDimension(a_shape, -2)) { - return errors::InvalidArgument( - "The 'a' arguments to TriangularSolve must be square matrices: ", - xla::ShapeUtil::HumanString(a_shape)); - } - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) { - return errors::InvalidArgument( - "Arguments to TriangularSolve have incompatible matrix shapes: ", - xla::ShapeUtil::HumanString(a_shape), " vs ", - xla::ShapeUtil::HumanString(b_shape)); - } - - if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to TriangularSolve must be >= 1; got ", - block_size); - } - - std::map base_computations; - auto get_base_triangular_solve = - [&](int k) -> xla::StatusOr { - xla::XlaComputation& computation = base_computations[k]; - if (computation.IsNull()) { - std::unique_ptr sub = builder->CreateSubBuilder( - tensorflow::strings::StrCat("trsm_base_", k)); - - auto a_param = sub->Parameter( - 0, - xla::ShapeUtil::MakeShape( - b_shape.element_type(), - PrependMajorDims(sub.get(), batch_dimensions, {k, k})), - "a"); - - std::array b_lastd; - if (left_side) { - b_lastd = {k, n}; - } else { - b_lastd = {m, k}; - } - auto b_param = sub->Parameter( - 1, - xla::ShapeUtil::MakeShape( - b_shape.element_type(), - PrependMajorDims(sub.get(), batch_dimensions, b_lastd)), - "b"); - - // We use a left-looking or right-looking subroutine on the block diagonal - // in the lower=true cases, while falling back to a recursive call in - // others. The left-looking and right-looking subroutines are written with - // a While loop and so yields much faster compile times. Moreover, they - // can give higher performance on smaller (sub)problems. - if (left_side && lower) { - TF_RETURN_IF_ERROR(TriangularSolveLeftLooking(sub.get(), a_param, - b_param, transpose_a, - conjugate_a) - .status()); - } else if (!left_side && lower) { - TF_RETURN_IF_ERROR(TriangularSolveRightLooking(sub.get(), a_param, - b_param, transpose_a, - conjugate_a) - .status()); - } else { - TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param, - left_side, lower, transpose_a, - conjugate_a, - /*block_size=*/1) - .status()); + const int ndims = xla::ShapeUtil::Rank(a_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to TriangularSolve must have rank >= 2: ", ndims); + } + // The batch dimensions must be equal. + std::vector batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = a_shape.dimensions(i); + int64 b_size = b_shape.dimensions(i); + if (a_size != b_size) { + return errors::InvalidArgument( + "Batch dimensions of arguments to TriangularSolve must be equal: ", + xla::ShapeUtil::HumanString(a_shape), " vs ", + xla::ShapeUtil::HumanString(b_shape)); } + batch_dimensions.push_back(a_size); + } - TF_ASSIGN_OR_RETURN(computation, sub->Build()); + if (xla::ShapeUtil::GetDimension(a_shape, -1) != + xla::ShapeUtil::GetDimension(a_shape, -2)) { + return errors::InvalidArgument( + "The 'a' arguments to TriangularSolve must be square matrices: ", + xla::ShapeUtil::HumanString(a_shape)); + } + const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); + if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) { + return errors::InvalidArgument( + "Arguments to TriangularSolve have incompatible matrix shapes: ", + xla::ShapeUtil::HumanString(a_shape), " vs ", + xla::ShapeUtil::HumanString(b_shape)); } - return &computation; - }; - - xla::XlaOp output = Zeros(builder, b_shape); - - // Right-looking blocked triangular solve. - // For an explanation of the algorithm, see the TRSM discussion in: - // Goto, Kazushige, and Robert Van De Geijn. "High-performance implementation - // of the level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 - // (2008): 4. - - // In the code comments below, T = lambda x: np.swapaxes(x, -1, -2) if - // conjugate_a is False, or T = lambda x: np.conj(np.swapaxes(x, -1, -2)) if - // conjugate_a is True. - - if (!left_side && lower == transpose_a) { - // for i in range(0, a.shape[-1], block_size): - for (int64 i = 0; i < n; i += block_size) { - int64 k = std::min(block_size, n - i); - - // output[..., :, i:i+k] = triangular_solve( - // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {0, i}, {m, i + k})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = builder->Call(*solve, {a_slice, b_slice}); - } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = builder->Div(b_slice, a_slice_conj); - } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {0, i})); - - // if i + k < a.shape[-1]: - // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2) - if (i + k < n) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); - } else { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, n})); - } - TF_ASSIGN_OR_RETURN(auto b_update, - BatchDot(builder, update, a_slice_2, - /*transpose_x=*/false, - /*transpose_y=*/transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/conjugate_a)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {0, i + k}, {m, n})); - b_update = builder->Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k})); - } + if (block_size < 1) { + return errors::InvalidArgument( + "block_size argument to TriangularSolve must be >= 1; got ", + block_size); } - } else if (left_side && lower != transpose_a) { - // for i in range(0, a.shape[-1], block_size): - for (int64 i = 0; i < m; i += block_size) { - int64 k = std::min(block_size, m - i); - - // output[..., i:i+k, :] = triangular_solve( - // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = builder->Call(*solve, {a_slice, b_slice}); - } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = builder->Div(b_slice, a_slice_conj); - } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); - - // if i + k < a.shape[-1]: - // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :]) - if (i + k < m) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k})); + std::map base_computations; + auto get_base_triangular_solve = + [&](int k) -> xla::StatusOr { + xla::XlaComputation& computation = base_computations[k]; + if (computation.IsNull()) { + std::unique_ptr sub = builder->CreateSubBuilder( + tensorflow::strings::StrCat("trsm_base_", k)); + + auto a_param = xla::Parameter( + sub.get(), 0, + xla::ShapeUtil::MakeShape(b_shape.element_type(), + ConcatVectors(batch_dimensions, {k, k})), + "a"); + + std::array b_lastd; + if (left_side) { + b_lastd = {k, n}; } else { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, m})); + b_lastd = {m, k}; + } + auto b_param = xla::Parameter( + sub.get(), 1, + xla::ShapeUtil::MakeShape(b_shape.element_type(), + ConcatVectors(batch_dimensions, b_lastd)), + "b"); + + // We use a left-looking or right-looking subroutine on the block + // diagonal in the lower=true cases, while falling back to a recursive + // call in others. The left-looking and right-looking subroutines are + // written with a While loop and so yields much faster compile times. + // Moreover, they can give higher performance on smaller (sub)problems. + if (left_side && lower) { + TriangularSolveLeftLooking(a_param, b_param, transpose_a, + conjugate_a); + } else if (!left_side && lower) { + TriangularSolveRightLooking(a_param, b_param, transpose_a, + conjugate_a); + } else { + TriangularSolve(a_param, b_param, left_side, lower, transpose_a, + conjugate_a, + /*block_size=*/1); } - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update, - /*transpose_x=*/transpose_a, - /*transpose_y=*/false, - /*conjugate_x=*/conjugate_a, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {i + k, 0}, {m, n})); - b_update = builder->Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {i + k, 0})); + TF_ASSIGN_OR_RETURN(computation, sub->Build()); } - } - } else if (!left_side && lower != transpose_a) { - // for i in reversed(range(0, a.shape[-1], block_size)): - const int64 last_blk_ix = xla::RoundUpToNearest(n, block_size) - block_size; - for (int64 i = last_blk_ix; i >= 0; i -= block_size) { - int64 k = std::min(block_size, n - i); - - // output[..., :, i:i+k] triangular_solve( - // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {0, i}, {m, i + k})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = builder->Call(*solve, {a_slice, b_slice}); - } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = builder->Div(b_slice, a_slice_conj); - } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {0, i})); - - // if i - k >= 0: - // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2) - if (i - k >= 0) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); + return &computation; + }; + + xla::XlaOp output = Zeros(builder, b_shape); + + // Right-looking blocked triangular solve. + // For an explanation of the algorithm, see the TRSM discussion in: + // Goto, Kazushige, and Robert Van De Geijn. "High-performance + // implementation of the level-3 BLAS." ACM Transactions on Mathematical + // Software (TOMS) 35.1 (2008): 4. + + // In the code comments below, T = lambda x: np.swapaxes(x, -1, -2) if + // conjugate_a is False, or T = lambda x: np.conj(np.swapaxes(x, -1, -2)) if + // conjugate_a is True. + + if (!left_side && lower == transpose_a) { + // for i in range(0, a.shape[-1], block_size): + for (int64 i = 0; i < n; i += block_size) { + int64 k = std::min(block_size, n - i); + + // output[..., :, i:i+k] = triangular_solve( + // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) + auto a_slice = SliceInMinorDims(a, {i, i}, {i + k, i + k}); + auto b_slice = SliceInMinorDims(b, {0, i}, {m, i + k}); + xla::XlaOp update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, + get_base_triangular_solve(k)); + update = xla::Call(builder, *solve, {a_slice, b_slice}); } else { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {0, i}, {i, i + k})); + auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a); + update = b_slice / a_slice_conj; } + output = UpdateSliceInMinorDims(output, update, {0, i}); + + // if i + k < a.shape[-1]: + // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2) + if (i + k < n) { + xla::XlaOp a_slice_2; + if (lower) { + a_slice_2 = SliceInMinorDims(a, {i + k, i}, {n, i + k}); + } else { + a_slice_2 = SliceInMinorDims(a, {i, i + k}, {i + k, n}); + } + + auto b_update = BatchDot(update, a_slice_2, + /*transpose_x=*/false, + /*transpose_y=*/transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/conjugate_a); + auto b_slice_2 = SliceInMinorDims(b, {0, i + k}, {m, n}); + b = UpdateSliceInMinorDims(b, b_slice_2 - b_update, {0, i + k}); + } + } - TF_ASSIGN_OR_RETURN(auto b_update, - BatchDot(builder, update, a_slice_2, - /*transpose_x=*/false, - /*transpose_y=*/transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/conjugate_a)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {0, 0}, {m, i})); - b_update = builder->Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0})); + } else if (left_side && lower != transpose_a) { + // for i in range(0, a.shape[-1], block_size): + for (int64 i = 0; i < m; i += block_size) { + int64 k = std::min(block_size, m - i); + + // output[..., i:i+k, :] = triangular_solve( + // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) + auto a_slice = SliceInMinorDims(a, {i, i}, {i + k, i + k}); + auto b_slice = SliceInMinorDims(b, {i, 0}, {i + k, n}); + xla::XlaOp update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, + get_base_triangular_solve(k)); + update = xla::Call(builder, *solve, {a_slice, b_slice}); + } else { + auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a); + update = b_slice / a_slice_conj; + } + output = UpdateSliceInMinorDims(output, update, {i, 0}); + + // if i + k < a.shape[-1]: + // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :]) + if (i + k < m) { + xla::XlaOp a_slice_2; + if (lower) { + a_slice_2 = SliceInMinorDims(a, {i + k, i}, {m, i + k}); + } else { + a_slice_2 = SliceInMinorDims(a, {i, i + k}, {i + k, m}); + } + + auto b_update = BatchDot(a_slice_2, update, + /*transpose_x=*/transpose_a, + /*transpose_y=*/false, + /*conjugate_x=*/conjugate_a, + /*conjugate_y=*/false); + auto b_slice_2 = SliceInMinorDims(b, {i + k, 0}, {m, n}); + b = UpdateSliceInMinorDims(b, b_slice_2 - b_update, {i + k, 0}); + } } - } - } else { // left_side && lower == transpose_a - // for i in reversed(range(0, a.shape[-1], block_size)): - const int64 last_blk_ix = xla::RoundUpToNearest(m, block_size) - block_size; - for (int64 i = last_blk_ix; i >= 0; i -= block_size) { - int64 k = std::min(block_size, m - i); - - // output[..., i:i+k, :] triangular_solve( - // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = builder->Call(*solve, {a_slice, b_slice}); - } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = builder->Div(b_slice, a_slice_conj); + } else if (!left_side && lower != transpose_a) { + // for i in reversed(range(0, a.shape[-1], block_size)): + const int64 last_blk_ix = + xla::RoundUpToNearest(n, block_size) - block_size; + for (int64 i = last_blk_ix; i >= 0; i -= block_size) { + int64 k = std::min(block_size, n - i); + + // output[..., :, i:i+k] triangular_solve( + // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) + auto a_slice = SliceInMinorDims(a, {i, i}, {i + k, i + k}); + auto b_slice = SliceInMinorDims(b, {0, i}, {m, i + k}); + xla::XlaOp update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, + get_base_triangular_solve(k)); + update = xla::Call(builder, *solve, {a_slice, b_slice}); + } else { + auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a); + update = b_slice / a_slice_conj; + } + output = UpdateSliceInMinorDims(output, update, {0, i}); + + // if i - k >= 0: + // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2) + if (i - k >= 0) { + xla::XlaOp a_slice_2; + if (lower) { + a_slice_2 = SliceInMinorDims(a, {i, 0}, {i + k, i}); + } else { + a_slice_2 = SliceInMinorDims(a, {0, i}, {i, i + k}); + } + + auto b_update = BatchDot(update, a_slice_2, + /*transpose_x=*/false, + /*transpose_y=*/transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/conjugate_a); + auto b_slice_2 = SliceInMinorDims(b, {0, 0}, {m, i}); + b = UpdateSliceInMinorDims(b, b_slice_2 - b_update, {0, 0}); + } } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); - - // if i - k >= 0: - // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :]) - if (i - k >= 0) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); + } else { // left_side && lower == transpose_a + // for i in reversed(range(0, a.shape[-1], block_size)): + const int64 last_blk_ix = + xla::RoundUpToNearest(m, block_size) - block_size; + for (int64 i = last_blk_ix; i >= 0; i -= block_size) { + int64 k = std::min(block_size, m - i); + + // output[..., i:i+k, :] triangular_solve( + // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) + auto a_slice = SliceInMinorDims(a, {i, i}, {i + k, i + k}); + auto b_slice = SliceInMinorDims(b, {i, 0}, {i + k, n}); + xla::XlaOp update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, + get_base_triangular_solve(k)); + update = xla::Call(builder, *solve, {a_slice, b_slice}); } else { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {0, i}, {i, i + k})); + auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a); + update = b_slice / a_slice_conj; + } + output = UpdateSliceInMinorDims(output, update, {i, 0}); + + // if i - k >= 0: + // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :]) + if (i - k >= 0) { + xla::XlaOp a_slice_2; + if (lower) { + a_slice_2 = SliceInMinorDims(a, {i, 0}, {i + k, i}); + } else { + a_slice_2 = SliceInMinorDims(a, {0, i}, {i, i + k}); + } + + auto b_update = BatchDot(a_slice_2, update, + /*transpose_x=*/transpose_a, + /*transpose_y=*/false, + /*conjugate_x=*/conjugate_a, + /*conjugate_y=*/false); + auto b_slice_2 = SliceInMinorDims(b, {0, 0}, {i, n}); + b = UpdateSliceInMinorDims(b, b_slice_2 - b_update, {0, 0}); } - - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update, - /*transpose_x=*/transpose_a, - /*transpose_y=*/false, - /*conjugate_x=*/conjugate_a, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {0, 0}, {i, n})); - b_update = builder->Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0})); } } - } - return output; + return output; + }); } -xla::StatusOr TriangularSolveLeftLooking(xla::XlaBuilder* builder, - const xla::XlaOp& a, - const xla::XlaOp& b, - bool transpose_a, - bool conjugate_a) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - const int64 ndims = xla::ShapeUtil::Rank(a_shape); - - std::vector batch_dimensions; - for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape.dimensions(i); - batch_dimensions.push_back(a_size); - } - - // The main computation is performed in a While loop. - - // Allocate the output and set its first or last row, - // output = np.zeros_like(b) - // if transpose_a: - // output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:] - // else: - // output[..., :1, :] = b[..., :1, :] / a[..., :1, :1] - xla::XlaOp output = Zeros(builder, b_shape); - { - auto i = transpose_a ? m - 1 : 0; - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + 1, i + 1})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {i, 0}, {i + 1, n})); - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - auto update = builder->Div(b_slice, a_slice_conj); - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); - } - - // Construct the initial loop carry tuple, - // if transpose_a: - // init = (m-2, output, a, b) - // else: - // init = (1, output, a, b) - std::vector tuple_shapes = { - // The loop iteration counter is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(xla::S32, {}), - // The output has the shape of b, with one row updated each iteration. - b_shape, - // The coefficient matrix a is a loop invariant. - a_shape, - // The right-hand-side matrix b is a loop invariant. - b_shape}; - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); - auto init_i = builder->ConstantR0(transpose_a ? m - 2 : 1); - auto init = builder->Tuple({init_i, output, a, b}); - - // Construct the loop condition function, - // def cond_fun(loop_carry): - // i, output, a, b = loop_carry - // return i >= 0 if transpose_a else i < m - std::unique_ptr condb = - builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond"); - { - auto i = condb->GetTupleElement( - condb->Parameter(0, tuple_shape, - "TriangularSolveLeftLookingWhileTuple"), - 0); - if (transpose_a) { - condb->Ge(i, condb->ConstantR0(0)); - } else { - condb->Lt(i, condb->ConstantR0(m)); +xla::XlaOp TriangularSolveLeftLooking(xla::XlaOp a, xla::XlaOp b, + bool transpose_a, bool conjugate_a) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); + const int64 ndims = xla::ShapeUtil::Rank(a_shape); + + std::vector batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = a_shape.dimensions(i); + batch_dimensions.push_back(a_size); } - } - TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); - - // Construct the loop body function, - // def body_fun(loop_carry): - // i, output, a, b = loop_carry - // if transpose_a: - // a_row = np.swapaxes(a[..., i+1:, i:i+1], -1 -2) - // else: - // a_row = a[..., i:i+1, :i] - // result_row = b[..., i:i+1, :] - np.matmul(a_row, output[..., :, :]) - // output[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] - // if transpose_a: - // return (i - 1, output, a, b) - // else: - // return (i + 1, output, a, b) - // We have to do some extra FLOPs propagating zeros in the matrix multiply - // because we can't have the size of its arguments depend on the loop counter. - std::unique_ptr bodyb = - builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody"); - { - auto input_tuple = bodyb->Parameter(0, tuple_shape, - "TriangularSolveLeftLookingWhileTuple"); - // i, output, a, b = loop_carry - auto i = bodyb->GetTupleElement(input_tuple, 0); - auto body_out = bodyb->GetTupleElement(input_tuple, 1); - auto body_a = bodyb->GetTupleElement(input_tuple, 2); - auto body_b = bodyb->GetTupleElement(input_tuple, 3); - auto zero = bodyb->ConstantR0(0); + // The main computation is performed in a While loop. - // We'd like to implement this: - // if transpose_a: - // a_row = T(a[..., i+1:, i:i+1]) - // result_row = (b[..., i:i+1, :] - // - np.matmul(a_row, body_out[..., i+1:, :])) - // else: - // result_row = (b[..., i:i+1, :] - // - np.matmul(a[..., i:i+1, :i], body_out[..., :i, :])) - // But since we can't have intermediate array sizes depend on the loop - // counter, we instead exploit the fact that we initialized the output to - // all zeros and use that as zero-padding (doing unnecessary FLOPs). - xla::XlaOp a_row; - if (transpose_a) { - TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a, - {zero, i}, {m, 1})); - } else { - TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a, - {i, zero}, {1, m})); + // Allocate the output and set its first or last row, + // output = np.zeros_like(b) + // if transpose_a: + // output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:] + // else: + // output[..., :1, :] = b[..., :1, :] / a[..., :1, :1] + xla::XlaOp output = Zeros(builder, b_shape); + { + auto i = transpose_a ? m - 1 : 0; + auto a_slice = SliceInMinorDims(a, {i, i}, {i + 1, i + 1}); + auto b_slice = SliceInMinorDims(b, {i, 0}, {i + 1, n}); + auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a); + auto update = b_slice / a_slice_conj; + output = UpdateSliceInMinorDims(output, update, {i, 0}); } - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), a_row, body_out, - /*transpose_x=*/transpose_a, - /*transpose_y=*/false, - /*conjugate_x=*/conjugate_a, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN( - auto result_row_slice, - DynamicSliceInMinorDims(bodyb.get(), body_b, {i, zero}, {1, n})); - auto result_row = bodyb->Sub(result_row_slice, b_update); - - // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] - TF_ASSIGN_OR_RETURN(auto a_elt, DynamicSliceInMinorDims(bodyb.get(), body_a, - {i, i}, {1, 1})); - TF_ASSIGN_OR_RETURN(auto a_elt_conj, - MaybeConjugate(bodyb.get(), a_elt, conjugate_a)); - auto div_result = bodyb->Div(result_row, a_elt_conj); - TF_ASSIGN_OR_RETURN(body_out, - DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, - div_result, {i, zero})); + // Construct the initial loop carry tuple, // if transpose_a: - // return (i - 1, body_out, a, b) + // init = (m-2, output, a, b) // else: - // return (i + 1, body_out, a, b) - auto next_i = bodyb->Add(i, bodyb->ConstantR0(transpose_a ? -1 : 1)); - bodyb->Tuple({next_i, body_out, body_a, body_b}); - } - TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); - - // Construct the While loop and return the result, - // return while_loop(cond_fun, body_fun, init)[1] - auto triangular_solve_left_looking_while = builder->While(cond, body, init); - return builder->GetTupleElement(triangular_solve_left_looking_while, 1); + // init = (1, output, a, b) + std::vector tuple_shapes = { + // The loop iteration counter is a scalar, incremented each iteration. + xla::ShapeUtil::MakeShape(xla::S32, {}), + // The output has the shape of b, with one row updated each iteration. + b_shape, + // The coefficient matrix a is a loop invariant. + a_shape, + // The right-hand-side matrix b is a loop invariant. + b_shape}; + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); + auto init_i = xla::ConstantR0(builder, transpose_a ? m - 2 : 1); + auto init = xla::Tuple(builder, {init_i, output, a, b}); + + // Construct the loop condition function, + // def cond_fun(loop_carry): + // i, output, a, b = loop_carry + // return i >= 0 if transpose_a else i < m + std::unique_ptr condb = + builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond"); + { + auto i = xla::GetTupleElement( + xla::Parameter(condb.get(), 0, tuple_shape, + "TriangularSolveLeftLookingWhileTuple"), + 0); + if (transpose_a) { + xla::Ge(i, xla::ConstantR0(condb.get(), 0)); + } else { + xla::Lt(i, xla::ConstantR0(condb.get(), m)); + } + } + TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); + + // Construct the loop body function, + // def body_fun(loop_carry): + // i, output, a, b = loop_carry + // if transpose_a: + // a_row = np.swapaxes(a[..., i+1:, i:i+1], -1 -2) + // else: + // a_row = a[..., i:i+1, :i] + // result_row = b[..., i:i+1, :] - np.matmul(a_row, output[..., :, :]) + // output[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] + // if transpose_a: + // return (i - 1, output, a, b) + // else: + // return (i + 1, output, a, b) + // We have to do some extra FLOPs propagating zeros in the matrix multiply + // because we can't have the size of its arguments depend on the loop + // counter. + std::unique_ptr bodyb = + builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody"); + { + auto input_tuple = xla::Parameter(bodyb.get(), 0, tuple_shape, + "TriangularSolveLeftLookingWhileTuple"); + + // i, output, a, b = loop_carry + auto i = xla::GetTupleElement(input_tuple, 0); + auto body_out = xla::GetTupleElement(input_tuple, 1); + auto body_a = xla::GetTupleElement(input_tuple, 2); + auto body_b = xla::GetTupleElement(input_tuple, 3); + auto zero = xla::ConstantR0(bodyb.get(), 0); + + // We'd like to implement this: + // if transpose_a: + // a_row = T(a[..., i+1:, i:i+1]) + // result_row = (b[..., i:i+1, :] + // - np.matmul(a_row, body_out[..., i+1:, :])) + // else: + // result_row = (b[..., i:i+1, :] + // - np.matmul(a[..., i:i+1, :i], body_out[..., :i, :])) + // But since we can't have intermediate array sizes depend on the loop + // counter, we instead exploit the fact that we initialized the output to + // all zeros and use that as zero-padding (doing unnecessary FLOPs). + xla::XlaOp a_row; + if (transpose_a) { + a_row = DynamicSliceInMinorDims(body_a, {zero, i}, {m, 1}); + } else { + a_row = DynamicSliceInMinorDims(body_a, {i, zero}, {1, m}); + } + auto b_update = BatchDot(a_row, body_out, + /*transpose_x=*/transpose_a, + /*transpose_y=*/false, + /*conjugate_x=*/conjugate_a, + /*conjugate_y=*/false); + auto result_row_slice = + DynamicSliceInMinorDims(body_b, {i, zero}, {1, n}); + auto result_row = result_row_slice - b_update; + + // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] + auto a_elt = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); + auto a_elt_conj = MaybeConjugate(a_elt, conjugate_a); + auto div_result = xla::Div(result_row, a_elt_conj); + body_out = DynamicUpdateSliceInMinorDims(body_out, div_result, {i, zero}); + + // if transpose_a: + // return (i - 1, body_out, a, b) + // else: + // return (i + 1, body_out, a, b) + auto next_i = xla::Add( + i, xla::ConstantR0(bodyb.get(), transpose_a ? -1 : 1)); + xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b}); + } + TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); + + // Construct the While loop and return the result, + // return while_loop(cond_fun, body_fun, init)[1] + auto triangular_solve_left_looking_while = xla::While(cond, body, init); + return xla::GetTupleElement(triangular_solve_left_looking_while, 1); + }); } -xla::StatusOr TriangularSolveRightLooking(xla::XlaBuilder* builder, - const xla::XlaOp& a, - const xla::XlaOp& b, - bool transpose_a, - bool conjugate_a) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - const int64 ndims = xla::ShapeUtil::Rank(a_shape); - - std::vector batch_dimensions; - for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape.dimensions(i); - batch_dimensions.push_back(a_size); - } - - // The main computation is performed in a While loop. - xla::XlaOp output = Zeros(builder, b_shape); - - // Construct the initial loop carry tuple, - // if transpose_a: - // init = (0, output, a, b) - // else: - // init = (n-1, output, a, b) - std::vector tuple_shapes = { - // The loop iteration counter is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(xla::S32, {}), - // The output has the shape of b, with one row updated each iteration. - b_shape, - // The coefficient matrix a is a loop invariant. - a_shape, - // The right-hand-side matrix b is a loop invariant. - b_shape}; - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); - auto init_i = builder->ConstantR0(transpose_a ? 0 : n - 1); - auto init = builder->Tuple({init_i, output, a, b}); - - // Construct the loop condition function, - // def cond_fun(loop_carry): - // i, output, a, b = loop_carry - // return i < n if transpose_a else i >= 0 - std::unique_ptr condb = - builder->CreateSubBuilder("TriangularSolveRightLookingWhileCond"); - { - auto i = condb->GetTupleElement( - condb->Parameter(0, tuple_shape, - "TriangularSolveRightLookingWhileTuple"), - 0); - if (transpose_a) { - condb->Lt(i, condb->ConstantR0(n)); - } else { - condb->Ge(i, condb->ConstantR0(0)); +xla::XlaOp TriangularSolveRightLooking(xla::XlaOp a, xla::XlaOp b, + bool transpose_a, bool conjugate_a) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); + const int64 ndims = xla::ShapeUtil::Rank(a_shape); + + std::vector batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = a_shape.dimensions(i); + batch_dimensions.push_back(a_size); } - } - TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); - - // Construct the loop body function, - // def body_fun(loop_carry): - // i, output, a, b = loop_carry - // if transpose_a: - // a_row = np.swapaxes(a[..., :, i:i+1], -1 -2) - // else: - // a_row = a[..., :, i:i+1] - // result_row = b[..., :, i:i+1] - np.matmul(output, a_row) - // output[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] - // if transpose_a: - // return (i - 1, output, a, b) - // else: - // return (i + 1, output, a, b) - // We have to do some extra FLOPs propagating zeros in the matrix multiply - // because we can't have the size of its arguments depend on the loop counter. - std::unique_ptr bodyb = - builder->CreateSubBuilder("TriangularSolveRightLookingWhileBody"); - { - auto input_tuple = bodyb->Parameter( - 0, tuple_shape, "TriangularSolveRightLookingWhileTuple"); - - // i, output, a, b = loop_carry - auto i = bodyb->GetTupleElement(input_tuple, 0); - auto body_out = bodyb->GetTupleElement(input_tuple, 1); - auto body_a = bodyb->GetTupleElement(input_tuple, 2); - auto body_b = bodyb->GetTupleElement(input_tuple, 3); - auto zero = bodyb->ConstantR0(0); - - // We'd like to implement b[..., :, i:i+1] - np.matmul(output, a[..., :, - // i:i+1]) But since we can't have intermediate array sizes depend on the - // loop counter, we instead exploit the fact that we initialized the output - // to all zeros and use that as zero-padding (doing unnecessary FLOPs). - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), body_out, body_a, - /*transpose_x=*/false, - /*transpose_y=*/transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/conjugate_a)); - // result = b - np.matmul(output, a) - auto result = bodyb->Sub(body_b, b_update); - // result_row = result[..., :, i:i+1] - TF_ASSIGN_OR_RETURN( - auto result_row, - DynamicSliceInMinorDims(bodyb.get(), result, {zero, i}, {m, 1})); - - // body_out[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] - TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(bodyb.get(), body_a, - {i, i}, {1, 1})); - TF_ASSIGN_OR_RETURN(auto a_ii_conj, - MaybeConjugate(bodyb.get(), a_ii, conjugate_a)); - auto div_result = bodyb->Div(result_row, a_ii_conj); - TF_ASSIGN_OR_RETURN(body_out, - DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, - div_result, {zero, i})); + // The main computation is performed in a While loop. + xla::XlaOp output = Zeros(builder, b_shape); + + // Construct the initial loop carry tuple, // if transpose_a: - // return (i + 1, body_out, a, b) + // init = (0, output, a, b) // else: - // return (i - 1, body_out, a, b) - auto next_i = bodyb->Add(i, bodyb->ConstantR0(transpose_a ? 1 : -1)); - bodyb->Tuple({next_i, body_out, body_a, body_b}); - } - TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); - - // Construct the While loop and return the result, - // return while_loop(cond_fun, body_fun, init)[1] - auto triangular_solve_left_looking_while = builder->While(cond, body, init); - return builder->GetTupleElement(triangular_solve_left_looking_while, 1); + // init = (n-1, output, a, b) + std::vector tuple_shapes = { + // The loop iteration counter is a scalar, incremented each iteration. + xla::ShapeUtil::MakeShape(xla::S32, {}), + // The output has the shape of b, with one row updated each iteration. + b_shape, + // The coefficient matrix a is a loop invariant. + a_shape, + // The right-hand-side matrix b is a loop invariant. + b_shape}; + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); + auto init_i = xla::ConstantR0(builder, transpose_a ? 0 : n - 1); + auto init = xla::Tuple(builder, {init_i, output, a, b}); + + // Construct the loop condition function, + // def cond_fun(loop_carry): + // i, output, a, b = loop_carry + // return i < n if transpose_a else i >= 0 + std::unique_ptr condb = + builder->CreateSubBuilder("TriangularSolveRightLookingWhileCond"); + { + auto i = xla::GetTupleElement( + xla::Parameter(condb.get(), 0, tuple_shape, + "TriangularSolveRightLookingWhileTuple"), + 0); + if (transpose_a) { + xla::Lt(i, xla::ConstantR0(condb.get(), n)); + } else { + xla::Ge(i, xla::ConstantR0(condb.get(), 0)); + } + } + TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); + + // Construct the loop body function, + // def body_fun(loop_carry): + // i, output, a, b = loop_carry + // if transpose_a: + // a_row = np.swapaxes(a[..., :, i:i+1], -1 -2) + // else: + // a_row = a[..., :, i:i+1] + // result_row = b[..., :, i:i+1] - np.matmul(output, a_row) + // output[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] + // if transpose_a: + // return (i - 1, output, a, b) + // else: + // return (i + 1, output, a, b) + // We have to do some extra FLOPs propagating zeros in the matrix multiply + // because we can't have the size of its arguments depend on the loop + // counter. + std::unique_ptr bodyb = + builder->CreateSubBuilder("TriangularSolveRightLookingWhileBody"); + { + auto input_tuple = xla::Parameter( + bodyb.get(), 0, tuple_shape, "TriangularSolveRightLookingWhileTuple"); + + // i, output, a, b = loop_carry + auto i = xla::GetTupleElement(input_tuple, 0); + auto body_out = xla::GetTupleElement(input_tuple, 1); + auto body_a = xla::GetTupleElement(input_tuple, 2); + auto body_b = xla::GetTupleElement(input_tuple, 3); + auto zero = xla::ConstantR0(bodyb.get(), 0); + + // We'd like to implement b[..., :, i:i+1] - np.matmul(output, a[..., :, + // i:i+1]) But since we can't have intermediate array sizes depend on the + // loop counter, we instead exploit the fact that we initialized the + // output to all zeros and use that as zero-padding (doing unnecessary + // FLOPs). + auto b_update = BatchDot(body_out, body_a, + /*transpose_x=*/false, + /*transpose_y=*/transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/conjugate_a); + // result = b - np.matmul(output, a) + auto result = body_b - b_update; + // result_row = result[..., :, i:i+1] + auto result_row = DynamicSliceInMinorDims(result, {zero, i}, {m, 1}); + + // body_out[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] + auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); + auto a_ii_conj = MaybeConjugate(a_ii, conjugate_a); + auto div_result = xla::Div(result_row, a_ii_conj); + body_out = DynamicUpdateSliceInMinorDims(body_out, div_result, {zero, i}); + + // if transpose_a: + // return (i + 1, body_out, a, b) + // else: + // return (i - 1, body_out, a, b) + auto next_i = xla::Add( + i, xla::ConstantR0(bodyb.get(), transpose_a ? 1 : -1)); + xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b}); + } + TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); + + // Construct the While loop and return the result, + // return while_loop(cond_fun, body_fun, init)[1] + auto triangular_solve_left_looking_while = xla::While(cond, body, init); + return xla::GetTupleElement(triangular_solve_left_looking_while, 1); + }); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h index 540c26b2473df9e7885f4e549b3e516a3d8a0d43..80c2bc4c9c38ec101db419d48db26e67e25d169b 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -57,23 +57,15 @@ namespace tensorflow { // // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no // blocking is used. -xla::StatusOr TriangularSolve(xla::XlaBuilder* builder, - const xla::XlaOp& a, xla::XlaOp b, - bool left_side, bool lower, - bool transpose_a, bool conjugate_a, - int64 block_size = 256); +xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, + bool lower, bool transpose_a, bool conjugate_a, + int64 block_size = 256); -xla::StatusOr TriangularSolveLeftLooking(xla::XlaBuilder* builder, - const xla::XlaOp& a, - const xla::XlaOp& b, - bool transpose_a, - bool conjugate_a); +xla::XlaOp TriangularSolveLeftLooking(xla::XlaOp a, xla::XlaOp b, + bool transpose_a, bool conjugate_a); -xla::StatusOr TriangularSolveRightLooking(xla::XlaBuilder* builder, - const xla::XlaOp& a, - const xla::XlaOp& b, - bool transpose_a, - bool conjugate_a); +xla::XlaOp TriangularSolveRightLooking(xla::XlaOp a, xla::XlaOp b, + bool transpose_a, bool conjugate_a); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc index 87ea4763f7c2357ae179b68ade3715b24c46432f..d5ffc1498e4b6dcfbc9f24f9b5dce58fddca8ab1 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc @@ -85,11 +85,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {0.5, 0.08333334, 0.04629629, 0.03367003}, @@ -107,11 +106,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, @@ -129,11 +127,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, @@ -151,11 +148,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/false, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/false, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {0.5, 0.08333334, 0.04629629, 0.03367003}, @@ -173,11 +169,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {-0.89646465, -0.69444444, -0.49242424}, @@ -196,11 +191,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/true, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {0.5, 1.0, 1.5}, @@ -219,11 +213,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {0.5, 1.0, 1.5}, @@ -242,11 +235,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {-0.89646465, -0.69444444, -0.49242424}, @@ -267,11 +259,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { CreateR2Parameter(AValsLowerComplex(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRightComplex(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/true, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/true, + /*block_size=*/2); xla::Array2D expected({ {0.5, complex64(0.08333333, 0.08333333), @@ -295,11 +286,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { CreateR2Parameter(AValsUpperComplex(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeftComplex(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {0.5, 1., 1.5}, @@ -323,10 +313,9 @@ XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolveLeftLooking(&builder, a, b, - /*transpose_a=*/false, - /*conjugate_a=*/false); - TF_ASSERT_OK(result.status()); + TriangularSolveLeftLooking(a, b, + /*transpose_a=*/false, + /*conjugate_a=*/false); xla::Array2D expected({ {0.5, 1.0, 1.5}, @@ -345,10 +334,9 @@ XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsFull(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolveLeftLooking(&builder, a, b, - /*transpose_a=*/false, - /*conjugate_a=*/false); - TF_ASSERT_OK(result.status()); + TriangularSolveLeftLooking(a, b, + /*transpose_a=*/false, + /*conjugate_a=*/false); xla::Array2D expected({ {0.5, 1.0, 1.5}, diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index d9ff7e6259f3fbab8957394bff5c5670a67dd0eb..66947294954a22c00c388df2255c28b2e36b301c 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -28,8 +29,8 @@ limitations under the License. namespace tensorflow { xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) { - return builder->Broadcast( - builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())), + return xla::Broadcast( + xla::ConstantLiteral(builder, xla::Literal::Zero(shape.element_type())), xla::AsInt64Slice(shape.dimensions())); } @@ -37,19 +38,19 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, double value) { switch (type) { case xla::F16: - return builder->ConstantR0(static_cast(value)); + return xla::ConstantR0(builder, static_cast(value)); break; case xla::BF16: - return builder->ConstantR0(static_cast(value)); + return xla::ConstantR0(builder, static_cast(value)); break; case xla::F32: - return builder->ConstantR0(static_cast(value)); + return xla::ConstantR0(builder, static_cast(value)); break; case xla::F64: - return builder->ConstantR0(value); + return xla::ConstantR0(builder, value); break; case xla::C64: - return builder->ConstantR0(value); + return xla::ConstantR0(builder, value); break; default: LOG(FATAL) << "unhandled element type " << type; @@ -107,134 +108,140 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, default: LOG(FATAL) << "unhandled element type " << type; } - return builder->ConstantLiteral(literal); + return xla::ConstantLiteral(builder, literal); } -xla::StatusOr SliceInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x, - gtl::ArraySlice start, - gtl::ArraySlice end) { - TF_RET_CHECK(start.size() == end.size()); - int64 n_minor_dims = start.size(); - - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_RET_CHECK(n_minor_dims <= n_dims); - gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - n_minor_dims); - - // Prepends 0s in the major dim - std::vector padded_start(n_dims, 0); - std::copy(start.begin(), start.end(), - padded_start.begin() + major_dims.size()); - - // Prepends the shape of the major dims. - std::vector padded_end(n_dims); - std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); - std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); - - std::vector strides(n_dims, 1); - return builder->Slice(x, padded_start, padded_end, strides); +xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, + gtl::ArraySlice end) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_RET_CHECK(start.size() == end.size()); + int64 n_minor_dims = start.size(); + + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + + const int64 n_dims = xla::ShapeUtil::Rank(shape); + TF_RET_CHECK(n_minor_dims <= n_dims); + gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), + /*pos=*/0, + /*len=*/n_dims - n_minor_dims); + + // Prepends 0s in the major dim + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + major_dims.size()); + + // Prepends the shape of the major dims. + std::vector padded_end(n_dims); + std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); + std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); + + std::vector strides(n_dims, 1); + return xla::Slice(x, padded_start, padded_end, strides); + }); } -std::vector PrependMajorDims(xla::XlaBuilder* builder, - const gtl::ArraySlice& major_dims, - const gtl::ArraySlice& indices) { - std::vector output(indices.size() + major_dims.size()); - std::copy(major_dims.begin(), major_dims.end(), output.begin()); - std::copy(indices.begin(), indices.end(), output.begin() + major_dims.size()); +std::vector ConcatVectors(gtl::ArraySlice xs, + gtl::ArraySlice ys) { + std::vector output(xs.size() + ys.size()); + std::copy(xs.begin(), xs.end(), output.begin()); + std::copy(ys.begin(), ys.end(), output.begin() + xs.size()); return output; } -xla::StatusOr DynamicSliceInMinorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, - const std::vector& starts, - const gtl::ArraySlice& sizes) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - int64 n_minor_dims = starts.size(); - TF_RET_CHECK(n_minor_dims == sizes.size()); - TF_RET_CHECK(n_minor_dims <= n_dims); - gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - sizes.size()); - TF_ASSIGN_OR_RETURN(auto padded_starts, - PrependZerosInMajorDims(builder, x, starts)); - auto padded_sizes = PrependMajorDims(builder, major_dims, sizes); - return builder->DynamicSlice(x, padded_starts, padded_sizes); +xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, + gtl::ArraySlice starts, + gtl::ArraySlice sizes) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + int64 n_minor_dims = starts.size(); + TF_RET_CHECK(n_minor_dims == sizes.size()); + TF_RET_CHECK(n_minor_dims <= n_dims); + gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), + /*pos=*/0, + /*len=*/n_dims - sizes.size()); + auto padded_starts = PrependZerosInMajorDims(x, starts); + auto padded_sizes = ConcatVectors(major_dims, sizes); + return xla::DynamicSlice(x, padded_starts, padded_sizes); + }); } -xla::StatusOr UpdateSlice(xla::XlaBuilder* builder, - const xla::XlaOp& x, - const xla::XlaOp& update, - gtl::ArraySlice start) { - // TODO(phawkins): make int64 work on all backends, remove the int32 cast. - std::vector start_as_int32(start.begin(), start.end()); - auto start_constant = builder->ConstantR1(start_as_int32); - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape, - builder->GetShape(start_constant)); - const int64 start_length = - xla::ShapeUtil::GetDimension(start_constant_shape, -1); - TF_RET_CHECK(start_length == n_dims); - return builder->DynamicUpdateSlice(x, update, start_constant); +xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice start) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + // TODO(phawkins): make int64 work on all backends, remove the int32 cast. + std::vector start_as_int32(start.begin(), start.end()); + auto start_constant = xla::ConstantR1(builder, start_as_int32); + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape, + builder->GetShape(start_constant)); + const int64 start_length = + xla::ShapeUtil::GetDimension(start_constant_shape, -1); + TF_RET_CHECK(start_length == n_dims); + return xla::DynamicUpdateSlice(x, update, start_constant); + }); } -xla::StatusOr UpdateSliceInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x, - const xla::XlaOp& update, - gtl::ArraySlice start) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - const int64 n_minor_dims = start.size(); - TF_RET_CHECK(n_minor_dims <= n_dims); - std::vector padded_start(n_dims, 0); - std::copy(start.begin(), start.end(), - padded_start.begin() + (n_dims - n_minor_dims)); - return UpdateSlice(builder, x, update, padded_start); +xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice start) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + const int64 n_minor_dims = start.size(); + TF_RET_CHECK(n_minor_dims <= n_dims); + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + (n_dims - n_minor_dims)); + return UpdateSlice(x, update, padded_start); + }); } -xla::StatusOr DynamicUpdateSliceInMinorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update, - const std::vector& starts) { - TF_ASSIGN_OR_RETURN(auto padded_starts, - PrependZerosInMajorDims(builder, x, starts)); - return builder->DynamicUpdateSlice(x, update, padded_starts); +xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice starts) { + auto padded_starts = PrependZerosInMajorDims(x, starts); + return xla::DynamicUpdateSlice(x, update, padded_starts); } -xla::StatusOr PrependZerosInMajorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, - const std::vector& starts) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - auto zero = builder->Reshape(builder->ConstantR0(0), {1}); - std::vector padded_starts(n_dims, zero); - for (int i = 0; i < starts.size(); ++i) { - padded_starts[n_dims - starts.size() + i] = - builder->Reshape(starts[i], {1}); - } - return builder->ConcatInDim(padded_starts, 0); +xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, + gtl::ArraySlice starts) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + auto zero = xla::Reshape(xla::ConstantR0(builder, 0), {1}); + std::vector padded_starts(n_dims, zero); + for (int i = 0; i < starts.size(); ++i) { + padded_starts[n_dims - starts.size() + i] = xla::Reshape(starts[i], {1}); + } + return xla::ConcatInDim(builder, padded_starts, 0); + }); } -xla::StatusOr TransposeInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_RET_CHECK(n_dims >= 2); - std::vector permutation(n_dims); - std::iota(permutation.begin(), permutation.end(), 0); - std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); - return builder->Transpose(x, permutation); +xla::XlaOp TransposeInMinorDims(xla::XlaOp x) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + TF_RET_CHECK(n_dims >= 2); + std::vector permutation(n_dims); + std::iota(permutation.begin(), permutation.end(), 0); + std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); + return xla::Transpose(x, permutation); + }); } -xla::StatusOr MaybeConjugate(xla::XlaBuilder* builder, - const xla::XlaOp& x, bool conjugate) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - auto perform_conj = shape.element_type() == xla::C64 && conjugate; - return perform_conj ? builder->Conj(x) : x; +xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + auto perform_conj = shape.element_type() == xla::C64 && conjugate; + return perform_conj ? xla::Conj(x) : x; + }); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 3c120a2548576d6ad46870583ca65beea63507a3..ac5d2940ffc7a18fbbf87818d4333c77e05441de 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -33,7 +33,7 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, // Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros // prepended until the array is length n_dims. -xla::XlaOp PrependZerosInMajorDims(xla::XlaBuilder* builder, +xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, gtl::ArraySlice starts); // Returns a integer scalar constant of 'type' with 'value'. @@ -41,54 +41,43 @@ xla::XlaOp PrependZerosInMajorDims(xla::XlaBuilder* builder, xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, int64 value); -// Builds a vector of zeros of length rank(x) with the last two values being +// Builds a vector of zeros of length rank(x) with the last values being // those in `starts`. -xla::StatusOr PrependZerosInMajorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, - const std::vector& starts); +xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, + gtl::ArraySlice starts); // Performs a slice in the minor dimensions of a Tensor. -xla::StatusOr SliceInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x, - gtl::ArraySlice start, - gtl::ArraySlice end); +xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, + gtl::ArraySlice end); -// Builds a 1-d vector out of a concatenation of `major_dims` and `starts`. -std::vector PrependMajorDims(xla::XlaBuilder* builder, - const gtl::ArraySlice& major_dims, - const gtl::ArraySlice& indices); +// Returns the concatenation of `xs` and `ys`. +std::vector ConcatVectors(gtl::ArraySlice xs, + gtl::ArraySlice ys); // Performs a dynamic slice in the minor dimensions of a Tensor. -xla::StatusOr DynamicSliceInMinorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, - const std::vector& starts, const gtl::ArraySlice& sizes); +xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, + gtl::ArraySlice starts, + gtl::ArraySlice sizes); // Updates a slice of 'x', i.e., // x[start[0], ..., start[n]] = update -xla::StatusOr UpdateSlice(xla::XlaBuilder* builder, - const xla::XlaOp& x, - const xla::XlaOp& update, - gtl::ArraySlice start); +xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice start); // Updates a slice of 'x', where 'start' contains a list of minor dimensions: // x[..., start[0], ..., start[n]] = update -xla::StatusOr UpdateSliceInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x, - const xla::XlaOp& update, - gtl::ArraySlice start); +xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice start); -xla::StatusOr DynamicUpdateSliceInMinorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update, - const std::vector& starts); +xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice starts); // Transposes a stack of matrices `x` by swapping the last two dimensions. -xla::StatusOr TransposeInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x); +xla::XlaOp TransposeInMinorDims(xla::XlaOp x); // Applies a complex conjugation operation if `a` is complex and `conjugate_a` // is true, otherwise returns its argument. -xla::StatusOr MaybeConjugate(xla::XlaBuilder* builder, - const xla::XlaOp& x, bool conjugate); +xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc index 265b39402c832f8c810a74f281563b05afdf2b1b..7d0f2222a9aa3ef09cb8be20c5f9b26431c6498c 100644 --- a/tensorflow/compiler/tf2xla/lib/util_test.cc +++ b/tensorflow/compiler/tf2xla/lib/util_test.cc @@ -70,8 +70,7 @@ XLA_TEST_F(UtilTest, Simple2dLookup) { auto a_data = CreateR2Parameter(BValsRight(), 0, "a", &builder, &a); auto x_data = CreateR0Parameter(2, 1, "x", &builder, &x); auto y_data = CreateR0Parameter(1, 2, "y", &builder, &y); - auto result = DynamicSliceInMinorDims(&builder, a, {x, y}, {1, 1}); - TF_ASSERT_OK(result.status()); + DynamicSliceInMinorDims(a, {x, y}, {1, 1}); ComputeAndCompareR2(&builder, {{10}}, {a_data.get(), x_data.get(), y_data.get()}, @@ -86,10 +85,8 @@ XLA_TEST_F(UtilTest, Simple3dLookup) { CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); auto index_data = CreateR0Parameter(1, 1, "index", &builder, &index); - TF_ASSERT_OK_AND_ASSIGN( - auto l_index, - DynamicSliceInMinorDims(&builder, a, - {index, builder.ConstantR0(0)}, {1, 4})); + DynamicSliceInMinorDims(a, {index, xla::ConstantR0(&builder, 0)}, + {1, 4}); ComputeAndCompareR3(&builder, {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}}, {a_data.get(), index_data.get()}); @@ -104,8 +101,7 @@ XLA_TEST_F(UtilTest, SimpleSliceUpdate) { auto x_data = CreateR0Parameter(2, 2, "x", &builder, &x); auto y_data = CreateR0Parameter(1, 3, "y", &builder, &y); - auto result = DynamicUpdateSliceInMinorDims(&builder, a, b, {x, y}); - TF_ASSERT_OK(result.status()); + DynamicUpdateSliceInMinorDims(a, b, {x, y}); xla::Array2D expected( {{{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 9, 1, -10}, {5, 8, 10, 11}}}); @@ -128,13 +124,9 @@ XLA_TEST_F(UtilTest, RowBatchDot) { // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); - TF_ASSERT_OK_AND_ASSIGN( - auto l_index, - DynamicSliceInMinorDims(&builder, a, - {index, builder.ConstantR0(0)}, {1, n})); - TF_ASSERT_OK_AND_ASSIGN( - auto dot, BatchDot(&builder, l_index, row, - /*transpose_x=*/false, /*transpose_y=*/true)); + auto l_index = DynamicSliceInMinorDims( + a, {index, xla::ConstantR0(&builder, 0)}, {1, n}); + BatchDot(l_index, row, /*transpose_x=*/false, /*transpose_y=*/true); ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, {a_data.get(), row_data.get(), index_data.get()}); diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc index 09ce594930efc0af47306590d76b322ac730f80f..7cc88f34d291f25814fba9f802c93117973120e7 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -39,7 +40,7 @@ xla::StatusOr> XlaWhileLoop( xla::XlaBuilder* builder) { std::vector elements(arity); for (int i = 0; i < arity; ++i) { - elements[i] = builder->GetTupleElement(tuple, i); + elements[i] = xla::GetTupleElement(tuple, i); } return elements; }; @@ -48,7 +49,8 @@ xla::StatusOr> XlaWhileLoop( std::unique_ptr cond_builder = builder->CreateSubBuilder(strings::StrCat(name, "_condition")); { - auto parameter = cond_builder->Parameter(0, tuple_shape, "parameter"); + auto parameter = + xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter"); TF_RETURN_IF_ERROR( condition_function(unpack_tuple(parameter, arity, cond_builder.get()), @@ -61,7 +63,8 @@ xla::StatusOr> XlaWhileLoop( std::unique_ptr body_builder = builder->CreateSubBuilder(strings::StrCat(name, "_body")); { - auto parameter = body_builder->Parameter(0, tuple_shape, "parameter"); + auto parameter = + xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter"); TF_ASSIGN_OR_RETURN( auto result, @@ -69,11 +72,11 @@ xla::StatusOr> XlaWhileLoop( body_builder.get())); TF_RET_CHECK(result.size() == initial_values.size()); - body_builder->Tuple(result); + xla::Tuple(body_builder.get(), result); } TF_ASSIGN_OR_RETURN(auto body, body_builder->Build()); - auto outputs = builder->While(cond, body, builder->Tuple(initial_values)); + auto outputs = xla::While(cond, body, xla::Tuple(builder, initial_values)); return unpack_tuple(outputs, arity, builder); } @@ -86,9 +89,8 @@ xla::StatusOr> XlaForEachIndex( auto while_cond_fn = [&](gtl::ArraySlice values, xla::XlaBuilder* cond_builder) -> xla::StatusOr { - return cond_builder->Lt( - values[0], - IntegerLiteral(cond_builder, num_iterations_type, num_iterations)); + return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type, + num_iterations)); }; auto while_body_fn = [&](gtl::ArraySlice values, xla::XlaBuilder* body_builder) @@ -97,9 +99,9 @@ xla::StatusOr> XlaForEachIndex( std::vector updated_values; updated_values.reserve(values.size()); - updated_values.push_back(body_builder->Add( - iteration, - body_builder->ConstantLiteral(xla::Literal::One(num_iterations_type)))); + updated_values.push_back(xla::Add( + iteration, xla::ConstantLiteral( + body_builder, xla::Literal::One(num_iterations_type)))); values.remove_prefix(1); TF_ASSIGN_OR_RETURN(std::vector body_outputs, @@ -112,7 +114,7 @@ xla::StatusOr> XlaForEachIndex( std::vector values; values.reserve(initial_values.size() + 1); values.push_back( - builder->ConstantLiteral(xla::Literal::Zero(num_iterations_type))); + xla::ConstantLiteral(builder, xla::Literal::Zero(num_iterations_type))); values.insert(values.end(), initial_values.begin(), initial_values.end()); TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values, diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index db56b128375ce8ff2faf12c5d7ea256bdfab0f63..b43405a1a407b5fa98dd740c62af91e048cc9490 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -22,24 +22,6 @@ limitations under the License. namespace tensorflow { -Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) { - xla::Shape literal_shape; - TF_RETURN_IF_ERROR(TensorShapeToXLAShape( - host_tensor.dtype(), host_tensor.shape(), &literal_shape)); - - *literal = xla::Literal(literal_shape); - - // memcpy over the payload ... - // TODO(phawkins): handle string types. - size_t total_bytes = host_tensor.TotalBytes(); - if (total_bytes > 0) { - void* dst_ptr = literal->untyped_data(); - const void* src_ptr = DMAHelper::base(&host_tensor); - memcpy(dst_ptr, src_ptr, total_bytes); - } - return Status::OK(); -} - Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, xla::BorrowingLiteral* literal) { xla::Shape xla_shape; diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index 74685025c1780c5c0ba56205a98786582e9191e9..ab7e861f3336097d2ea52487092f16edb5c14531 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -26,10 +26,6 @@ limitations under the License. namespace tensorflow { -// Copies 'host_tensor' to an XLA Literal. Fails if host_tensor is of an -// unsupported type. -Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); - // Returns a BorrowingLiteral that utilizes the same underlying buffer owned by // 'host_tensor'. Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index bb9168fa358154f3db9dab87bacc9bf28dd16406..ace6fd1d8eeaf439509a7b75d8d986997c392e73 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -8,12 +8,7 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") cc_library( name = "xla_ops", - srcs = [ - "dynamic_slice_ops.cc", - "functional_ops.cc", - "reduce_window_op.cc", - "sendrecv_ops.cc", - ], + srcs = ["xla_ops.cc"], deps = [ "//tensorflow/core:framework", ], diff --git a/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc deleted file mode 100644 index d6c0edbb889b1751ac9d9d47d0c9534b543196ff..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" - -namespace tensorflow { - -REGISTER_OP("XlaDynamicUpdateSlice") - .Input("input: T") - .Input("update: T") - .Input("indices: Tindices") - .Output("output: T") - .Attr("T: type") - .Attr("Tindices: {int32, int64}") - .SetShapeFn(shape_inference::UnchangedShape) - .Doc(R"doc( -Wraps the XLA DynamicUpdateSlice operator, documented at - https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice -. - -XlaDynamicUpdateSlice generates a result which is the value of the `input` -operand, with a slice update overwritten at `indices`. The shape of `update` -determines the shape of the sub-array of the result which is updated. The shape -of indices must be rank == 1, with dimension size equal to the rank of `input`. - -Handling of out-of-bounds slice indices is implementation-defined. - -input: A `Tensor` of type T. -indices: A vector of indices into `input`. Must have length equal to the rank of - `input`. -update: A `Tensor` of type T. Same rank as `input`. -output: A `Tensor` of type T. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/functional_ops.cc b/tensorflow/compiler/tf2xla/ops/functional_ops.cc deleted file mode 100644 index 4a669f8e6eaf644f119f3c0a66f29d9f2c9a9d16..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/ops/functional_ops.cc +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -// TODO(b/37549631) setting the While Op to always be stateful is too -// conservative. -REGISTER_OP("XlaWhile") - .Input("input: T") - .Output("output: T") - .Attr("T: list(type) >= 0") - .Attr("cond: func") - .Attr("body: func") - .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -output = input; While (Cond(output)) { output = Body(output) } - -input: A list of input tensors whose types are T. -output: A list of output tensors whose types are T. -cond: A function takes 'input' and returns a tensor. If the tensor is - a scalar of non-boolean, the scalar is converted to a boolean - according to the following rule: if the scalar is a numerical - value, non-zero means True and zero means False; if the scalar is - a string, non-empty means True and empty means False. If the - tensor is not a scalar, non-emptiness means True and False - otherwise. -body: A function that takes a list of tensors and returns another - list of tensors. Both lists have the same types as specified by T. -)doc"); - -// TODO(b/37549631) setting the If Op to always be stateful is too -// conservative. -REGISTER_OP("XlaIf") - .Input("cond: Tcond") - .Input("inputs: Tin") - .Output("output: Tout") - .Attr("Tcond: type") - .Attr("then_branch: func") - .Attr("else_branch: func") - .Attr("Tin: list(type) >= 0") - .Attr("Tout: list(type) >= 0") - .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -output = cond ? then_branch(inputs) : else_branch(inputs). - -cond: A boolean scalar. -inputs: A list of input tensors. -output: A list of tensors returned by either then_branch(inputs) or - else_branch(inputs). The input shapes of the then_branch and - else_branch must match. -then_branch: A function takes 'inputs' and returns a list of tensors, - whose types are the same as what else_branch returns. -else_branch: A function takes 'inputs' and returns a list of tensors. - whose types are the same as what then_branch returns. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc b/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc deleted file mode 100644 index d9af982adc090ea78c711fd4656ba429c53b18c9..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -REGISTER_OP("XlaReduceWindow") - .Input("input: T") - .Input("init_value: T") - .Attr("T: numbertype") - .Attr("computation: func") - .Attr("window_dimensions: list(int)") - .Attr("window_strides: list(int)") - .Attr("padding_low: list(int)") - .Attr("padding_high: list(int)") - .Output("output: T") - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -Wraps the XLA ReduceWindow operator, documented at - https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . - -input: the input tensor -init_value: a scalar representing the initial value for the reduction -computation: a reducer function to apply -window_dimensions: the shape of the window -window_strides: the inter-window strides -padding_low: the padding to apply at the start of each input dimensions -padding_high: the padding to apply at the end of each input dimension. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc deleted file mode 100644 index 7ec7b50e905a6cbdecea4543dcb87322b5a7e844..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -REGISTER_OP("XlaSend") - .Input("tensor: T") - .Attr("T: type") - .Attr("tensor_name: string") - .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -Sends the named tensor to another XLA computation. Wraps the XLA Send operator -documented at - https://www.tensorflow.org/performance/xla/operation_semantics#send . - -tensor: The tensor to send. -tensor_name: A string key that identifies the channel. -)doc"); - -REGISTER_OP("XlaRecv") - .Output("tensor: dtype") - .Attr("dtype: type") - .Attr("tensor_name: string") - .Attr("shape: shape") - .SetIsStateful() - .SetShapeFn([](shape_inference::InferenceContext* c) { - TensorShape shape_attr; - TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr)); - shape_inference::ShapeHandle s; - TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s)); - c->set_output(0, s); - return Status::OK(); - }) - .Doc(R"doc( -Receives the named tensor from another XLA computation. Wraps the XLA Recv -operator documented at - https://www.tensorflow.org/performance/xla/operation_semantics#recv . - -tensor: The tensor to receive. -dtype: The type of the tensor. -tensor_name: A string key that identifies the channel. -shape: The shape of the tensor. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..a59c77f5c3a309abe8f6fbab1e48455d54e8fae5 --- /dev/null +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -0,0 +1,182 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("XlaDynamicUpdateSlice") + .Input("input: T") + .Input("update: T") + .Input("indices: Tindices") + .Output("output: T") + .Attr("T: type") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Wraps the XLA DynamicUpdateSlice operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice +. + +XlaDynamicUpdateSlice generates a result which is the value of the `input` +operand, with a slice update overwritten at `indices`. The shape of `update` +determines the shape of the sub-array of the result which is updated. The shape +of indices must be rank == 1, with dimension size equal to the rank of `input`. + +Handling of out-of-bounds slice indices is implementation-defined. + +input: A `Tensor` of type T. +indices: A vector of indices into `input`. Must have length equal to the rank of + `input`. +update: A `Tensor` of type T. Same rank as `input`. +output: A `Tensor` of type T. +)doc"); + +// TODO(b/37549631) setting the If Op to always be stateful is too +// conservative. +REGISTER_OP("XlaIf") + .Input("cond: Tcond") + .Input("inputs: Tin") + .Output("output: Tout") + .Attr("Tcond: type") + .Attr("then_branch: func") + .Attr("else_branch: func") + .Attr("Tin: list(type) >= 0") + .Attr("Tout: list(type) >= 0") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +output = cond ? then_branch(inputs) : else_branch(inputs). + +cond: A boolean scalar. +inputs: A list of input tensors. +output: A list of tensors returned by either then_branch(inputs) or + else_branch(inputs). The input shapes of the then_branch and + else_branch must match. +then_branch: A function takes 'inputs' and returns a list of tensors, + whose types are the same as what else_branch returns. +else_branch: A function takes 'inputs' and returns a list of tensors. + whose types are the same as what then_branch returns. +)doc"); + +REGISTER_OP("XlaRecv") + .Output("tensor: dtype") + .Attr("dtype: type") + .Attr("tensor_name: string") + .Attr("shape: shape") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + TensorShape shape_attr; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr)); + shape_inference::ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s)); + c->set_output(0, s); + return Status::OK(); + }) + .Doc(R"doc( +Receives the named tensor from another XLA computation. Wraps the XLA Recv +operator documented at + https://www.tensorflow.org/performance/xla/operation_semantics#recv . + +tensor: The tensor to receive. +dtype: The type of the tensor. +tensor_name: A string key that identifies the channel. +shape: The shape of the tensor. +)doc"); + +REGISTER_OP("XlaReduceWindow") + .Input("input: T") + .Input("init_value: T") + .Attr("T: numbertype") + .Attr("computation: func") + .Attr("window_dimensions: list(int)") + .Attr("window_strides: list(int)") + .Attr("padding_low: list(int)") + .Attr("padding_high: list(int)") + .Output("output: T") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Wraps the XLA ReduceWindow operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . + +input: the input tensor +init_value: a scalar representing the initial value for the reduction +computation: a reducer function to apply +window_dimensions: the shape of the window +window_strides: the inter-window strides +padding_low: the padding to apply at the start of each input dimensions +padding_high: the padding to apply at the end of each input dimension. +)doc"); + +REGISTER_OP("XlaSend") + .Input("tensor: T") + .Attr("T: type") + .Attr("tensor_name: string") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Sends the named tensor to another XLA computation. Wraps the XLA Send operator +documented at + https://www.tensorflow.org/performance/xla/operation_semantics#send . + +tensor: The tensor to send. +tensor_name: A string key that identifies the channel. +)doc"); + +REGISTER_OP("XlaSort") + .Input("input: T") + .Output("output: T") + .Attr("T: type") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Wraps the XLA Sort operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#sort +. + +Sorts a tensor. Currently only rank 1 sorts in ascending order are supported. + +input: A `Tensor` of type T. +output: A `Tensor` of type T. +)doc"); + +// TODO(b/37549631) setting the While Op to always be stateful is too +// conservative. +REGISTER_OP("XlaWhile") + .Input("input: T") + .Output("output: T") + .Attr("T: list(type) >= 0") + .Attr("cond: func") + .Attr("body: func") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +output = input; While (Cond(output)) { output = Body(output) } + +input: A list of input tensors whose types are T. +output: A list of output tensors whose types are T. +cond: A function takes 'input' and returns a tensor. If the tensor is + a scalar of non-boolean, the scalar is converted to a boolean + according to the following rule: if the scalar is a numerical + value, non-zero means True and zero means False; if the scalar is + a string, non-empty means True and empty means False. If the + tensor is not a scalar, non-emptiness means True and False + otherwise. +body: A function that takes a list of tensors and returns another + list of tensors. Both lists have the same types as specified by T. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index e5ce65bec950fdfd38c3ca5bc62ac745ef8ca4a7..2fc47dffb8f5f16f24e3beb1ff75aeed3e857c58 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -77,4 +77,6 @@ def reduce_window(operand, recv = gen_xla_ops.xla_recv send = gen_xla_ops.xla_send +sort = gen_xla_ops.xla_sort + while_loop = gen_xla_ops.xla_while diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 9c8e56a17e07348d3cfaaca0b5eb335295af05c3..0c98c208053b47f4f92cc46f2280271847e88b61 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" @@ -338,9 +339,9 @@ Status BuildComputation( const std::vector& arg_cores, const std::vector& retvals, const std::vector>& resources, - bool return_updated_values_for_all_resources, xla::XlaBuilder* builder, - xla::XlaComputation* computation, int* num_computation_outputs, - int* num_nonconst_outputs, + bool return_updated_values_for_all_resources, bool always_return_tuple, + xla::XlaBuilder* builder, xla::XlaComputation* computation, + int* num_computation_outputs, int* num_nonconst_outputs, std::vector* outputs, std::vector* resource_updates) { std::vector elems; @@ -384,13 +385,14 @@ Status BuildComputation( const XlaCompiler::Argument& arg = args[resource->arg_num()]; const int core = arg_cores[resource->arg_num()]; DCHECK_LT(resource->arg_num(), arg_cores.size()); - bool modified = resource->value() != resource->initial_value(); + bool modified = !resource->value().IsIdenticalTo(resource->initial_value()); // TensorArray gradients were modified if their values changed or there are // any newly created gradients. for (const auto& grad : resource->tensor_array_gradients()) { - modified = modified || - grad.second->value() != grad.second->initial_value() || - arg.tensor_array_gradients.count(grad.first) == 0; + modified = + modified || + !grad.second->value().IsIdenticalTo(grad.second->initial_value()) || + arg.tensor_array_gradients.count(grad.first) == 0; } if (return_updated_values_for_all_resources || modified) { resource_updates->emplace_back(); @@ -415,7 +417,7 @@ Status BuildComputation( // create a tuple/get-tuple-element combination so that sharding // assignment will be placed on this value, which will cause the resource // update to be returned from the same device that provided the resource. - handle = builder->GetTupleElement(builder->Tuple({handle}), 0); + handle = xla::GetTupleElement(xla::Tuple(builder, {handle}), 0); elems.push_back(handle); } @@ -424,7 +426,9 @@ Status BuildComputation( *num_computation_outputs = elems.size(); // Builds the XLA computation. - builder->Tuple(elems); + if (always_return_tuple || elems.size() != 1) { + xla::Tuple(builder, elems); + } builder->ClearOpMetadata(); xla::StatusOr computation_status = builder->Build(); @@ -551,16 +555,16 @@ Status XlaCompiler::BuildArguments( } xla::XlaScopedShardingAssignment assign_tuple_sharding(builder, tuple_sharding); - tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); + tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); } else { - tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); + tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); } for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { const int core = (*arg_cores)[input_mapping->at(i)]; xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() : xla::sharding_builder::AssignDevice(core)); - arg_handles[i] = builder->GetTupleElement(tuple, i); + arg_handles[i] = xla::GetTupleElement(tuple, i); } } else { for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { @@ -568,8 +572,8 @@ Status XlaCompiler::BuildArguments( xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() : xla::sharding_builder::AssignDevice(core)); - arg_handles[i] = - builder->Parameter(i, (*input_shapes)[i], strings::StrCat("arg", i)); + arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i], + strings::StrCat("arg", i)); } } @@ -600,7 +604,7 @@ Status XlaCompiler::BuildArguments( // return values of functions, and then reshape unconditionally. if (is_entry_computation) { arg_expression.set_handle( - builder->Reshape(arg_handles[i], arg.shape.dim_sizes())); + xla::Reshape(arg_handles[i], arg.shape.dim_sizes())); } else { arg_expression.set_handle(arg_handles[i]); } @@ -681,7 +685,7 @@ string ValidateFunctionDef(const FunctionDef* fdef, Status ValidateGraph(const Graph* graph, const FunctionLibraryDefinition& flib_def, const DeviceType& device_type, const string& name) { - std::vector invalid_ops; + std::set invalid_ops; for (const Node* node : graph->nodes()) { if (node->type_string() == FunctionLibraryDefinition::kGradientOp) { continue; @@ -690,19 +694,19 @@ Status ValidateGraph(const Graph* graph, if (fdef) { string error_msg = ValidateFunctionDef(fdef, flib_def); if (!error_msg.empty()) { - invalid_ops.push_back( + invalid_ops.insert( strings::StrCat(node->def().op(), ":{", error_msg, "}")); } continue; } const OpDef* op_def; if (!OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def).ok()) { - invalid_ops.push_back(node->def().op()); + invalid_ops.insert(node->def().op()); continue; } TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def)); if (!FindKernelDef(device_type, node->def(), nullptr, nullptr).ok()) { - invalid_ops.push_back(node->def().op()); + invalid_ops.insert(node->def().op()); } } if (!invalid_ops.empty()) { @@ -767,9 +771,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, result->outputs.resize(context->retvals().size()); TF_RETURN_IF_ERROR(BuildComputation( args, arg_cores, context->retvals(), context->resources(), - options.return_updated_values_for_all_resources, &builder, - result->computation.get(), &num_computation_outputs, - &num_nonconst_outputs, &result->outputs, &result->resource_updates)); + options.return_updated_values_for_all_resources, + options.always_return_tuple, &builder, result->computation.get(), + &num_computation_outputs, &num_nonconst_outputs, &result->outputs, + &result->resource_updates)); VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index c93850ce270502ea1df1f6469963e96e86994fa2..80593eaca5e695cd93f14d52d4af88e7624bf105 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -52,13 +52,7 @@ class XlaContext; // (kind kResource). // // Only kParameter and initialized kResource arguments become runtime parameters -// to the generated XLA computation. The XLA computation will have run-time -// parameters in the following order: -// +---------------------+-----------------------------------------+ -// | kParameter values | Initial values of kResource arguments | -// +---------------------+-----------------------------------------+ -// Within each block, the arguments are arranged by the _Arg index from which -// they were derived. +// to the generated XLA computation. // // The run-time outputs of the XLA computation are arranged in the following // order: @@ -77,10 +71,10 @@ class XlaContext; // tensors with a different shape to their representation inside the XLA // computation. // -// In both inputs and outputs, kResource values are placed the end. When +// In computation outputs, updated kResource values are placed the end. When // emitting While loop bodies, we must ensure that the loop body has -// identical input and output signatures. By moving variable values -// to the end of the argument list and using the +// identical input and output signatures. By passing variable values +// at the end of the argument list and using the // `return_updated_values_for_all_variables` option, we can ensure that the // input and output values of resources appear at the same positions. // @@ -175,6 +169,11 @@ class XlaCompiler { // computation. bool resolve_compile_time_constants = true; + // If 'always_return_tuple' is true, then the output of a computation will + // always be a tuple. Otherwise, a single-element output will not be wrapped + // in a tuple. + bool always_return_tuple = true; + // True when compiling the entry computation, false for subcomputations // (while, call, etc.) bool is_entry_computation = true; @@ -234,7 +233,8 @@ class XlaCompiler { tf2xla::HostComputeMetadata host_compute_metadata; // Resources whose values were updated by the computation, ordered - // by return value position. Resource updates follow the non-constant + // by return value position (which is the same as the order the resources + // were passed as arguments). Resource updates follow the non-constant // results in the outputs of XLA computation. std::vector resource_updates; diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 098072d33cd4eb7f7dec0ec4196b43eca0220d4a..d0b560690758a4d73c4836ad97470d52e45fc59e 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -92,7 +92,7 @@ void XlaContext::AddRetval(int retval_index, DataType type, } Status XlaContext::AddConstRetval(int retval_index, DataType dtype, - const xla::Literal& literal) { + const xla::LiteralSlice& literal) { VLOG(1) << "Adding retval index " << retval_index << " with non-data-dependent tensor to XLA computation"; if (retvals_.size() <= retval_index) { @@ -131,9 +131,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { xla::XlaBuilder b("max<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); - auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); - auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); - b.Max(x, y); + auto x = + xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = + xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + xla::Max(x, y); return b.Build().ConsumeValueOrDie(); }); } @@ -145,9 +147,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) { xla::XlaBuilder b("min<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); - auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); - auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); - b.Min(x, y); + auto x = + xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = + xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + xla::Min(x, y); return b.Build().ConsumeValueOrDie(); }); } @@ -159,9 +163,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) { xla::XlaBuilder b("add<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); - auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); - auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); - b.Add(x, y); + auto x = + xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = + xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + xla::Add(x, y); return b.Build().ConsumeValueOrDie(); }); } @@ -173,9 +179,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateMul(const DataType type) { xla::XlaBuilder b("mul<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); - auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); - auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); - b.Mul(x, y); + auto x = + xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = + xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + xla::Mul(x, y); return b.Build().ConsumeValueOrDie(); }); } diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 341bf6ff1f37fa7cd81f41c02a941214067b1bd1..5960daaefd625a0b4daf00d7b8c929f3c856575f 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -83,7 +83,7 @@ class XlaContext : public ResourceBase { // As for Retval, but for return values that are compile-time constants. Status AddConstRetval(int retval_index, DataType dtype, - const xla::Literal& literal); + const xla::LiteralSlice& literal); // Creates a resource with resource `kind` and initial value `handle`. `name` // is a descriptive name for use in error messages. See the `XlaResource` diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index a1da176fe30ddd0d4460a51b60b2568ecc1af6aa..81bdf139f56a5e0a20fb4e1a99cb3d4afc833159 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -23,9 +23,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -51,14 +51,14 @@ Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, xla::PrimitiveType xla_output_type; TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(output_type, &xla_output_type)); - xla::XlaOp input_max = builder->Reduce(input, init_value, *reducer, - /*dimensions_to_reduce=*/{axis}); + xla::XlaOp input_max = xla::Reduce(input, init_value, *reducer, + /*dimensions_to_reduce=*/{axis}); std::vector broadcast_dims(input_shape.dims() - 1); std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); // Compute a mask that has 1s for elements equal to the maximum. - xla::XlaOp partial_mask = builder->ConvertElementType( - builder->Eq(input, input_max, broadcast_dims), xla_output_type); + xla::XlaOp partial_mask = xla::ConvertElementType( + xla::Eq(input, input_max, broadcast_dims), xla_output_type); // In order to make identity elements for a bitwise And, we: // Left shift the 1 to the leftmost bit, yielding 0x10...0 @@ -68,24 +68,23 @@ Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, xla::ShapeUtil::ByteSizeOfPrimitiveType(xla_output_type) * 8 - 1; xla::XlaOp shift_amount = XlaHelpers::IntegerLiteral(builder, output_type, bits_in_type); - xla::XlaOp full_mask = builder->ShiftRightArithmetic( - builder->ShiftLeft(partial_mask, shift_amount), shift_amount); + xla::XlaOp full_mask = xla::ShiftRightArithmetic( + xla::ShiftLeft(partial_mask, shift_amount), shift_amount); // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its // index. - xla::XlaOp iota; const int64 axis_size = input_shape.dim_size(axis); - TF_RETURN_IF_ERROR(XlaHelpers::Iota(builder, output_type, axis_size, &iota)); + xla::XlaOp iota = xla::Iota(builder, xla_output_type, axis_size); xla::XlaOp product = - builder->And(full_mask, iota, /*broadcast_dimensions=*/{axis}); + xla::And(full_mask, iota, /*broadcast_dimensions=*/{axis}); // If there are multiple maximum elements, choose the one with the highest // index. xla::XlaOp output = - builder->Reduce(product, XlaHelpers::MinValue(builder, output_type), - *ctx->GetOrCreateMax(output_type), - /*dimensions_to_reduce=*/{axis}); + xla::Reduce(product, XlaHelpers::MinValue(builder, output_type), + *ctx->GetOrCreateMax(output_type), + /*dimensions_to_reduce=*/{axis}); *argminmax = output; return Status::OK(); } @@ -95,38 +94,75 @@ Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, xla::XlaOp XlaHelpers::MinValue(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::Literal::MinValue(type)); + return xla::ConstantLiteral(b, xla::Literal::MinValue(type)); +} + +xla::XlaOp XlaHelpers::MinFiniteValue(xla::XlaBuilder* b, DataType data_type) { + xla::PrimitiveType type; + TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); + switch (type) { + case xla::F16: + return xla::ConstantR0( + b, Eigen::NumTraits::lowest()); + case xla::BF16: + return xla::ConstantR0(b, bfloat16::lowest()); + case xla::F32: + return xla::ConstantR0(b, -std::numeric_limits::max()); + case xla::F64: + return xla::ConstantR0(b, -std::numeric_limits::max()); + default: + return xla::ConstantLiteral(b, xla::Literal::MinValue(type)); + } } xla::XlaOp XlaHelpers::MaxValue(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::Literal::MaxValue(type)); + return xla::ConstantLiteral(b, xla::Literal::MaxValue(type)); +} + +xla::XlaOp XlaHelpers::MaxFiniteValue(xla::XlaBuilder* b, DataType data_type) { + xla::PrimitiveType type; + TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); + switch (type) { + case xla::F16: + return xla::ConstantR0( + b, Eigen::NumTraits::highest()); + case xla::BF16: + return xla::ConstantR0(b, bfloat16::highest()); + case xla::F32: + return xla::ConstantR0(b, std::numeric_limits::max()); + case xla::F64: + return xla::ConstantR0(b, std::numeric_limits::max()); + default: + return xla::ConstantLiteral(b, xla::Literal::MaxValue(type)); + } } xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::Literal::Zero(type)); + return xla::ConstantLiteral(b, xla::Literal::Zero(type)); } xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::Literal::One(type)); + return xla::ConstantLiteral(b, xla::Literal::One(type)); } xla::XlaOp XlaHelpers::Epsilon(xla::XlaBuilder* b, DataType data_type) { switch (data_type) { case DT_HALF: - return b->ConstantR0( + return xla::ConstantR0( + b, static_cast(Eigen::NumTraits::epsilon())); case DT_BFLOAT16: - return b->ConstantR0(bfloat16::epsilon()); + return xla::ConstantR0(b, bfloat16::epsilon()); case DT_FLOAT: - return b->ConstantR0(std::numeric_limits::epsilon()); + return xla::ConstantR0(b, std::numeric_limits::epsilon()); case DT_DOUBLE: - return b->ConstantR0(std::numeric_limits::epsilon()); + return xla::ConstantR0(b, std::numeric_limits::epsilon()); default: LOG(FATAL) << "Unsupported type in XlaHelpers::Epsilon: " << DataTypeString(data_type); @@ -194,31 +230,6 @@ Status XlaHelpers::ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, axis, /*is_min=*/true, argmin); } -Status XlaHelpers::Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, - xla::XlaOp* iota) { - TensorShape linspace_shape({size}); - Tensor linspace; - switch (dtype) { - case DT_UINT8: - linspace = MakeLinspaceTensor(linspace_shape, size); - break; - case DT_INT32: - linspace = MakeLinspaceTensor(linspace_shape, size); - break; - case DT_INT64: - linspace = MakeLinspaceTensor(linspace_shape, size); - break; - default: - return errors::InvalidArgument("Invalid argument type ", - DataTypeString(dtype)); - } - xla::BorrowingLiteral linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); - - *iota = builder->ConstantLiteral(linspace_literal); - return Status::OK(); -} - Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, DataType index_type, const TensorShape& indices_shape, const xla::XlaOp& indices, const xla::XlaOp& on_value, @@ -248,6 +259,7 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, return errors::InvalidArgument("Invalid argument type ", DataTypeString(index_type)); } + xla::BorrowingLiteral linspace_literal; TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); @@ -256,17 +268,19 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, std::vector broadcast_dims(indices_shape.dims()); std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - xla::XlaOp one_hot_bool = builder->Eq( - indices, builder->ConstantLiteral(linspace_literal), broadcast_dims); + xla::XlaOp one_hot_bool = xla::Eq( + indices, xla::ConstantLiteral(builder, linspace_literal), broadcast_dims); // Selects the user-provided off_value and on_value values. - *one_hot = builder->Select( - one_hot_bool, builder->Broadcast(on_value, output_shape.dim_sizes()), - builder->Broadcast(off_value, output_shape.dim_sizes())); + *one_hot = xla::Select(one_hot_bool, + xla::Broadcast(on_value, output_shape.dim_sizes()), + xla::Broadcast(off_value, output_shape.dim_sizes())); return Status::OK(); } DataType XlaHelpers::SumAccumulationType(const DataType& dtype) { + // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from + // repeated floating point additions. if (dtype == DT_BFLOAT16 || dtype == DT_HALF) { return DT_FLOAT; } @@ -278,7 +292,7 @@ xla::XlaOp XlaHelpers::ConvertElementType(xla::XlaBuilder* const builder, const DataType new_element_type) { xla::PrimitiveType convert_to; TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to)); - return builder->ConvertElementType(operand, convert_to); + return xla::ConvertElementType(operand, convert_to); } } // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index c3fdc5252e74363fe289eeabb2cb0d68298ee291..495bd2b8b6fb48ffeb52e186324433b8e66e3aca 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -29,13 +29,21 @@ namespace tensorflow { class XlaHelpers { public: // Returns a handle representing the minimum value of a scalar - // element of data_type. + // element of data_type. -inf for floating-point types. static xla::XlaOp MinValue(xla::XlaBuilder* b, DataType data_type); - // Returns a handle representing the maximum value of a scalar + // Returns a handle representing the minimum finite value of a scalar // element of data_type. + static xla::XlaOp MinFiniteValue(xla::XlaBuilder* b, DataType data_type); + + // Returns a handle representing the maximum value of a scalar + // element of data_type. inf for floating point types. static xla::XlaOp MaxValue(xla::XlaBuilder* b, DataType data_type); + // Returns a handle representing the maximum finite value of a scalar + // element of data_type. + static xla::XlaOp MaxFiniteValue(xla::XlaBuilder* b, DataType data_type); + // Returns a handle representing the zero value of a scalar // element of data_type. static xla::XlaOp Zero(xla::XlaBuilder* b, DataType data_type); @@ -81,10 +89,6 @@ class XlaHelpers { DataType input_type, DataType output_type, int axis, xla::XlaOp* argmin); - // Sets *iota to a rank 1 tensor with values [0, 1, 2, ...] of `dtype`. - static Status Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, - xla::XlaOp* iota); - // Converts `indices` into a one-hot representation. `depth` is the size // of the new axis to add. `axis` is the position at which to add the new // axis. `indices_shape` is the shape of `indices`. `on_value` and diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 76c68d81af4dd9ec40fe6b1c33b03a876a0c6dc6..0eabfb3a527be35972be2860709f1658601ce3a4 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -19,7 +19,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/core/common_runtime/dma_helper.h" namespace tensorflow { @@ -38,8 +41,7 @@ xla::XlaBuilder* XlaOpKernelContext::builder() const { static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { const XlaExpression* expression = reinterpret_cast(tensor.tensor_data().data()); - CHECK(expression->handle().builder() != nullptr || - expression->resource() != nullptr); + CHECK(expression->handle().valid() || expression->resource() != nullptr); VLOG(1) << "Fetched T" << expression->handle(); return expression; } @@ -48,7 +50,7 @@ static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) { const XlaExpression* expression = reinterpret_cast(tensor->tensor_data().data()); - CHECK_EQ(expression->handle().builder(), nullptr); + CHECK(!expression->handle().valid()); return const_cast(expression); } @@ -67,6 +69,20 @@ TensorShape XlaOpKernelContext::InputShape(int index) { return context_->input(index).shape(); } +DataType XlaOpKernelContext::input_type(int index) const { + return context_->input(index).dtype(); +} + +xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) { + xla::PrimitiveType type; + Status status = DataTypeToPrimitiveType(input_type(index), &type); + if (!status.ok()) { + SetStatus(status); + return xla::PRIMITIVE_TYPE_INVALID; + } + return type; +} + Status XlaOpKernelContext::ConstantInput(int index, xla::Literal* constant_literal) { return ConstantInputReshaped( @@ -87,6 +103,25 @@ Status XlaOpKernelContext::ConstantInputReshaped( } const XlaExpression* expression = CastExpressionFromTensor(tensor); + auto copy_tensor_to_literal = [](const Tensor& tensor, + xla::Literal* literal) { + xla::Shape literal_shape; + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), &literal_shape)); + + *literal = xla::Literal(literal_shape); + + // memcpy over the payload ... + // TODO(phawkins): handle string types. + size_t total_bytes = tensor.TotalBytes(); + if (total_bytes > 0) { + void* dst_ptr = literal->untyped_data(); + const void* src_ptr = DMAHelper::base(&tensor); + memcpy(dst_ptr, src_ptr, total_bytes); + } + return Status::OK(); + }; + // If the tensor has a known constant value, there is no need to invoke XLA. if (expression->has_constant_value()) { Tensor temp(tensor.dtype()); @@ -95,19 +130,21 @@ Status XlaOpKernelContext::ConstantInputReshaped( // with the enclosing Tensor. return errors::Internal("Incompatible shapes in ConstantInputReshaped."); } - return HostTensorToLiteral(temp, constant_literal); + + return copy_tensor_to_literal(temp, constant_literal); } // Make sure we treat zero-element tensors as constant. if (new_shape.num_elements() == 0) { Tensor temp(tensor.dtype(), new_shape); - return HostTensorToLiteral(temp, constant_literal); + + return copy_tensor_to_literal(temp, constant_literal); } xla::XlaOp handle = expression->handle(); if (new_shape != tensor.shape()) { // Reshape the handle to the desired shape. - handle = builder()->Reshape(handle, new_shape.dim_sizes()); + handle = xla::Reshape(handle, new_shape.dim_sizes()); } // The XLA layout is specified minor to major, and TensorFlow's minor @@ -162,7 +199,8 @@ Status XlaOpKernelContext::ConstantInputReshaped( } // Converts an int32 or int64 scalar literal to an int64. -static Status LiteralToInt64Scalar(const xla::Literal& literal, int64* out) { +static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal, + int64* out) { if (xla::ShapeUtil::Rank(literal.shape()) != 0) { return errors::InvalidArgument("value is not a scalar"); } @@ -177,7 +215,8 @@ static Status LiteralToInt64Scalar(const xla::Literal& literal, int64* out) { } // Converts an float32 or float64 scalar literal to a float64. -static Status LiteralToFloat64Scalar(const xla::Literal& literal, double* out) { +static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal, + double* out) { if (xla::ShapeUtil::Rank(literal.shape()) != 0) { return errors::InvalidArgument("value is not a scalar"); } @@ -204,7 +243,7 @@ Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) { } // Converts an int32 or int64 1D literal to an int64 vector. -static Status LiteralToInt64Vector(const xla::Literal& literal, +static Status LiteralToInt64Vector(const xla::LiteralSlice& literal, std::vector* out) { if (xla::ShapeUtil::Rank(literal.shape()) != 1) { return errors::InvalidArgument("value is not 1D"); @@ -319,8 +358,7 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, if (representation_shape == variable->shape()) { *value = variable->value(); } else { - *value = - builder()->Reshape(variable->value(), variable->shape().dim_sizes()); + *value = xla::Reshape(variable->value(), variable->shape().dim_sizes()); } return Status::OK(); } @@ -368,10 +406,11 @@ void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { const TensorShape& shape = constant.shape(); - xla::Literal literal; - OP_REQUIRES_OK(context_, HostTensorToLiteral(constant, &literal)); - xla::XlaOp handle = builder()->ConstantLiteral(literal); - CHECK_NE(handle.builder(), nullptr); + xla::BorrowingLiteral literal; + OP_REQUIRES_OK(context_, HostTensorToBorrowingLiteral(constant, &literal)); + + xla::XlaOp handle = xla::ConstantLiteral(builder(), literal); + CHECK(handle.valid()); // Make the Tensor that will refer to the expression. Tensor* output = nullptr; @@ -416,7 +455,7 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, xla::XlaOp handle) { - TF_RET_CHECK(handle.builder() != nullptr); + TF_RET_CHECK(handle.valid()); const XlaExpression* expression = CastExpressionFromTensor(context_->input(input_index)); @@ -438,7 +477,7 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, TensorShape representation_shape = xla_context.RepresentationShape(shape, type); if (shape != representation_shape) { - handle = builder()->Reshape(handle, representation_shape.dim_sizes()); + handle = xla::Reshape(handle, representation_shape.dim_sizes()); } return variable->SetValue(handle); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 667dc262ca03ca716ffbf015a78fc14c7a8b7c1a..2bde2c983d0cca05558e86a36698d6f0e097705a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/macros.h" @@ -67,7 +68,12 @@ class XlaOpKernelContext { int num_inputs() const { return context_->num_inputs(); } // Returns the type of input 'index'. - DataType input_type(int index) { return context_->input(index).dtype(); } + DataType input_type(int index) const; + + // Returns the type of input 'index' as an xla::PrimitiveType. If the type + // is not representable as an XLA type, sets an error status and returns + // xla::PRIMITIVE_TYPE_INVALID. + xla::PrimitiveType input_xla_type(int index); // Returns the shape of input 'index'. TensorShape InputShape(int index); diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 4692038b61f6871a8a16299fd4d11e963eb46a57..46785bc1f0a1279bfd67a55844fe238d9797382b 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -71,16 +71,18 @@ XlaOpRegistry::~XlaOpRegistry() = default; << " have incompatible allow_resource_types settings."; return false; } - if (!x.has_device_whitelist || !y.has_device_whitelist) { - LOG(WARNING) << "Registrations of " << x.name - << " do not both have device whitelists."; + if (!x.has_device_whitelist && !y.has_device_whitelist) { + LOG(WARNING) << "Duplicate registrations of " << x.name + << "with no device whitelists."; return false; } - for (const auto& device : x.device_whitelist) { - if (y.device_whitelist.count(device) != 0) { - LOG(WARNING) << "Multiple registrations of " << x.name << " on device " - << device; - return false; + if (x.has_device_whitelist && y.has_device_whitelist) { + for (const auto& device : x.device_whitelist) { + if (y.device_whitelist.count(device) != 0) { + LOG(WARNING) << "Multiple registrations of " << x.name << " on device " + << device; + return false; + } } } if (x.compile_time_constant_inputs != y.compile_time_constant_inputs) { @@ -157,97 +159,143 @@ void XlaOpRegistry::RegisterCompilationKernels() { registry.jit_kernels_registered_ = true; OpRegistryInterface* op_registry = OpRegistry::Global(); - for (const auto& op : registry.ops_) { - const string& op_name = op.first; - const std::unique_ptr& op_registration = op.second; - const OpDef* op_def; - Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def); - if (!lookup_status.ok()) { - LOG(ERROR) << lookup_status.error_message(); - XLA_LOG_LINES( - ERROR, "Ops registered: \n" + - dynamic_cast(op_registry)->DebugString(true)); + // Order of op registration: + // The goal is to allow the co-existence of backend-specific kernels and + // generic kernels. To achieve this, we enforce the following order of + // registrations for one op: + // 1. Process op registration with device whitelists: + // this pass registers backend-specific kernels for this op. + // 2. Process op registration without device whitelists: + // this pass registers the kernels for all the other supported backends. + for (auto& ops : registry.ops_) { + const string& op_name = ops.first; + std::vector>& op_registrations = ops.second; + // Partition the op registration so that the ones with device whitelists + // precede the one without device whitelist. + std::partition(op_registrations.begin(), op_registrations.end(), + [](const std::unique_ptr& op_reg) { + return op_reg->has_device_whitelist; + }); + + // Collect a set of backend registered by ops with device whitelists. + // The op registration without whitelists will register a generic kernel + // for all other backends not in this set. + std::unordered_set whitelisted_backend; + for (auto& op_registration : op_registrations) { + if (op_registration->has_device_whitelist) { + whitelisted_backend.insert(op_registration->device_whitelist.begin(), + op_registration->device_whitelist.end()); + } } - TF_CHECK_OK(lookup_status); - std::unordered_set type_attrs; - for (const OpDef::AttrDef& attr_def : op_def->attr()) { - if (attr_def.type() == "type" || attr_def.type() == "list(type)") { - type_attrs.insert(attr_def.name()); + for (auto& op_registration : op_registrations) { + const OpDef* op_def; + Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def); + if (!lookup_status.ok()) { + LOG(ERROR) << lookup_status.error_message(); + XLA_LOG_LINES( + ERROR, + "Ops registered: \n" + + dynamic_cast(op_registry)->DebugString(true)); } - } + TF_CHECK_OK(lookup_status); - // Checks there are no type constraints referring to unknown attributes. - for (const auto& constraint : op_registration->type_constraints) { - if (type_attrs.find(constraint.first) == type_attrs.end()) { - LOG(FATAL) << "Unknown type attribute " << constraint.first - << " in XLA op registration for " << op_name; + std::unordered_set type_attrs; + for (const OpDef::AttrDef& attr_def : op_def->attr()) { + if (attr_def.type() == "type" || attr_def.type() == "list(type)") { + type_attrs.insert(attr_def.name()); + } } - } - for (auto& backend : registry.backends_) { - // If the operator has a device whitelist, only register on whitelisted - // devices. - if (op_registration->has_device_whitelist && - op_registration->device_whitelist.find(backend.first) == - op_registration->device_whitelist.end()) { - continue; + // Checks there are no type constraints referring to unknown attributes. + for (const auto& constraint : op_registration->type_constraints) { + if (type_attrs.find(constraint.first) == type_attrs.end()) { + LOG(FATAL) << "Unknown type attribute " << constraint.first + << " in XLA op registration for " << op_name; + } } - std::unique_ptr kdef(new KernelDef); - kdef->set_op(op_registration->name); - kdef->set_device_type(backend.first); - - // Constrain each type attribute to the intersection of: - // a) the types supported by the backend, and - // b) the types allowed by the OpDef, and - // c) the type constraints. - for (const string& type_attr : type_attrs) { - KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint(); - attr_constraint->set_name(type_attr); - auto* allowed_values = - attr_constraint->mutable_allowed_values()->mutable_list(); - - const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def); - const auto* op_def_allowed_types = - op_def_attr.has_allowed_values() - ? &op_def_attr.allowed_values().list().type() - : nullptr; - auto constraint_it = op_registration->type_constraints.find(type_attr); - const std::set* type_constraints = - constraint_it != op_registration->type_constraints.end() - ? &constraint_it->second - : nullptr; - for (DataType dtype : backend.second.supported_types) { - // Filter out types that aren't allowed by the OpDef. - if (op_def_allowed_types != nullptr && - std::find(op_def_allowed_types->begin(), - op_def_allowed_types->end(), - dtype) == op_def_allowed_types->end()) { - continue; + for (auto& backend : registry.backends_) { + // If the operator has a device whitelist, only register on whitelisted + // devices. + if (op_registration->has_device_whitelist && + op_registration->device_whitelist.find(backend.first) == + op_registration->device_whitelist.end()) { + continue; + } + + // If the operator does NOT has a device whitelist, skip all devices + // that has already been registered. + if (!op_registration->has_device_whitelist && + whitelisted_backend.find(backend.first) != + whitelisted_backend.end()) { + continue; + } + + std::unique_ptr kdef(new KernelDef); + kdef->set_op(op_registration->name); + kdef->set_device_type(backend.first); + + // Constrain each type attribute to the intersection of: + // a) the types supported by the backend, and + // b) the types allowed by the OpDef, and + // c) the type constraints. + bool unsatisfiable_type_constraint = false; + for (const string& type_attr : type_attrs) { + KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint(); + attr_constraint->set_name(type_attr); + auto* allowed_values = + attr_constraint->mutable_allowed_values()->mutable_list(); + + const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def); + const auto* op_def_allowed_types = + op_def_attr.has_allowed_values() + ? &op_def_attr.allowed_values().list().type() + : nullptr; + auto constraint_it = + op_registration->type_constraints.find(type_attr); + const std::set* type_constraints = + constraint_it != op_registration->type_constraints.end() + ? &constraint_it->second + : nullptr; + for (DataType dtype : backend.second.supported_types) { + // Filter out types that aren't allowed by the OpDef. + if (op_def_allowed_types != nullptr && + std::find(op_def_allowed_types->begin(), + op_def_allowed_types->end(), + dtype) == op_def_allowed_types->end()) { + continue; + } + // Filter out types based on the type constraints. + if (type_constraints != nullptr && + type_constraints->find(dtype) == type_constraints->end()) { + continue; + } + // Passed all the filters, this type is allowed. + allowed_values->add_type(dtype); } - // Filter out types based on the type constraints. - if (type_constraints != nullptr && - type_constraints->find(dtype) == type_constraints->end()) { - continue; + if (op_registration->allow_resource_types) { + allowed_values->add_type(DT_RESOURCE); + } + // Don't build KernelDefs that have unsatisfiable type constraints. + if (allowed_values->type().empty()) { + unsatisfiable_type_constraint = true; + break; } - // Passed all the filters, this type is allowed. - allowed_values->add_type(dtype); } - if (op_registration->allow_resource_types) { - allowed_values->add_type(DT_RESOURCE); + if (unsatisfiable_type_constraint) continue; + + if (backend.second.op_filter != nullptr && + !backend.second.op_filter(kdef.get())) { + continue; } + VLOG(2) << "XLA op registration: device: " << backend.first + << " op: " << op_name; + registry.kernel_registrars_.emplace_back( + new kernel_factory::OpKernelRegistrar( + new KernelDef(*kdef), "XlaJitOp", op_registration->factory)); + backend.second.kernel_defs.push_back(std::move(kdef)); } - if (backend.second.op_filter != nullptr && - !backend.second.op_filter(kdef.get())) { - continue; - } - VLOG(2) << "XLA op registration: device: " << backend.first - << " op: " << op_name; - registry.kernel_registrars_.emplace_back( - new kernel_factory::OpKernelRegistrar( - new KernelDef(*kdef), "XlaJitOp", op_registration->factory)); - backend.second.kernel_defs.push_back(std::move(kdef)); } } } @@ -265,12 +313,12 @@ std::vector XlaOpRegistry::DeviceKernels( << "Unknown backend " << compilation_device_name; for (const std::unique_ptr& k : it->second.kernel_defs) { auto op_iter = registry.ops_.find(k->op()); - CHECK(op_iter != registry.ops_.end()); + CHECK(op_iter != registry.ops_.end() && !op_iter->second.empty()); // The test in IsCompatible ensures that if there are multiple matching // registrations for this op name, they all have the same value of // compilation_only, so only the first match needs to be tested. if (include_compilation_only_kernels || - !op_iter->second->compilation_only) { + !op_iter->second.front()->compilation_only) { kernels.push_back(k.get()); } } @@ -282,10 +330,13 @@ XlaOpRegistry::CompileTimeConstantInputs(const string& op) { XlaOpRegistry& registry = Instance(); mutex_lock lock(registry.mutex_); auto it = registry.ops_.find(op); - if (it == registry.ops_.end()) { + if (it == registry.ops_.end() || it->second.empty()) { return nullptr; } - return &it->second->compile_time_constant_inputs; + // The test in IsCompatible ensures that if there are multiple matching + // registrations for this op name, they all have the same value of + // compile_time_constant_inputs, so only the first match is returned. + return &it->second.front()->compile_time_constant_inputs; } std::vector XlaOpRegistry::BackendNames() { @@ -378,16 +429,15 @@ XlaOpRegistrar::XlaOpRegistrar( std::unique_ptr registration) { XlaOpRegistry& registry = XlaOpRegistry::Instance(); mutex_lock lock(registry.mutex_); - auto existing_ops = registry.ops_.equal_range(registration->name); - for (auto existing = existing_ops.first; existing != existing_ops.second; - ++existing) { - if (!XlaOpRegistry::IsCompatible(*existing->second, *registration)) { + auto& existing_ops = registry.ops_[registration->name]; + for (auto& existing : existing_ops) { + if (!XlaOpRegistry::IsCompatible(*existing, *registration)) { LOG(FATAL) << "XLA op registration " << registration->name << " is incompatible with existing registration of the same name."; } } - registry.ops_.emplace(registration->name, std::move(registration)); + existing_ops.emplace_back(std::move(registration)); } XlaBackendRegistrar::XlaBackendRegistrar( diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index e255b01dd7fdcb095c7992d4352d2d9bb7d36ac3..2d4593ea4999ad6d8cd0f0e2eec9c6d69c3020b8 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -203,7 +203,7 @@ class XlaOpRegistry { // Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP. // Registrations present under the same key must satisfy IsCompatible above, // and this is checked during registration. - std::unordered_multimap> ops_ + std::unordered_map>> ops_ GUARDED_BY(mutex_); // Have we already registered the JIT kernels on the JIT devices? diff --git a/tensorflow/compiler/tf2xla/xla_op_registry_test.cc b/tensorflow/compiler/tf2xla/xla_op_registry_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7b3b15b1af7636fddd4c29477cbfe6f9761f2c47 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_op_registry_test.cc @@ -0,0 +1,119 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +// This test is to verify the correctness of XLA op registration with specific +// backend overrides. + +// A dummy backend-specific OpKernel for CPU. +class DummyCPUOp : public XlaOpKernel { + public: + explicit DummyCPUOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + ctx->SetOutput(0, ctx->Input(0)); + } +}; + +// A dummy generic OpKernel for all backends. +class DummyGenericOp : public XlaOpKernel { + public: + explicit DummyGenericOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + ctx->SetOutput(0, ctx->Input(0)); + } +}; + +REGISTER_OP("DummyDuplicateOp") + .Attr("T: {float, int32}") + .Input("input: int32") + .Output("output: int32") + .Doc(R"doc( +A dummy Op. + +input: dummy input. +output: dummy output. +)doc"); + +// Register the DummyCPUOp kernel for CPU with type INT32. +REGISTER_XLA_OP(Name("DummyDuplicateOp") + .Device(DEVICE_CPU_XLA_JIT) + .TypeConstraint("T", DT_INT32), + DummyCPUOp); +// Register the DummyGeneric kernel for all registered device (except CPU since +// it is already registered), with type FLOAT. +REGISTER_XLA_OP(Name("DummyDuplicateOp").TypeConstraint("T", DT_FLOAT), + DummyGenericOp); + +// Test the correctness of registered kernels. The kernel registered for CPU +// should have type INT32 while all other kernels should have type FLOAT. +TEST(XlaOpRegistryTest, XlaOpRegistrationWithOverride) { + XlaOpRegistry::RegisterCompilationKernels(); + auto registered_kernels = GetAllRegisteredKernels().kernel(); + for (const auto& kernels : registered_kernels) { + if (kernels.op() == "DummyDuplicateOp") { + EXPECT_EQ(kernels.constraint_size(), 1); + EXPECT_EQ(kernels.constraint(0).name(), "T"); + if (kernels.device_type() == "XLA_CPU_JIT") { + EXPECT_EQ(kernels.constraint(0).allowed_values().list().type(0), + DT_INT32); + } else { + EXPECT_EQ(kernels.constraint(0).allowed_values().list().type(0), + DT_FLOAT); + } + } + } +} + +// A dummy generic OpKernel for all backends. +class DummyInfeasibleTypeConstraintOp : public XlaOpKernel { + public: + explicit DummyInfeasibleTypeConstraintOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + LOG(FATAL) << "unreachable"; + } +}; + +REGISTER_OP("DummyInfeasibleTypeConstraintOp") + .Attr("T: {float, string}") + .Input("input: T") + .Output("output: T") + .Doc(R"doc( +A dummy Op. + +input: dummy input. +output: dummy output. +)doc"); +REGISTER_XLA_OP( + Name("DummyInfeasibleTypeConstraintOp").TypeConstraint("T", DT_STRING), + DummyInfeasibleTypeConstraintOp); + +TEST(XlaOpRegistryTest, OpWithInfeasibleTypeConstraintIsNotRegistered) { + XlaOpRegistry::RegisterCompilationKernels(); + auto registered_kernels = GetAllRegisteredKernels().kernel(); + for (const auto& kernels : registered_kernels) { + // The operator should not be registered. + EXPECT_NE(kernels.op(), "DummyInfeasibleTypeConstraintOp"); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 540c65c597f20d5bb26494e56c09ff2187cfb0db..baea8149658ec0849ebb570931ca68518ec5284e 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { @@ -89,16 +90,16 @@ Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { } switch (kind_) { case kVariable: { - value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_), - shape_.dim_sizes()); + value_ = + xla::Broadcast(XlaHelpers::Zero(builder, type_), shape_.dim_sizes()); break; } case kTensorArray: { TensorShape ta_shape; ta_shape.AddDim(tensor_array_size_); ta_shape.AppendShape(shape_); - value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_), - ta_shape.dim_sizes()); + value_ = xla::Broadcast(XlaHelpers::Zero(builder, type_), + ta_shape.dim_sizes()); break; } case kStack: { @@ -106,9 +107,9 @@ Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { ta_shape.AddDim(tensor_array_size_); ta_shape.AppendShape(shape_); value_ = - builder->Tuple({builder->Broadcast(XlaHelpers::Zero(builder, type_), - ta_shape.dim_sizes()), - builder->ConstantR0(0)}); + xla::Tuple(builder, {xla::Broadcast(XlaHelpers::Zero(builder, type_), + ta_shape.dim_sizes()), + xla::ConstantR0(builder, 0)}); break; } @@ -130,8 +131,8 @@ Status XlaResource::GetOrCreateTensorArrayGradient(const string& source, TensorShape ta_shape; ta_shape.AddDim(tensor_array_size_); ta_shape.AppendShape(shape_); - xla::XlaOp gradient_value = builder->Broadcast( - XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); + xla::XlaOp gradient_value = + xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); gradient.reset( new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, /*name=*/strings::StrCat("TensorArrayGrad: ", name_), @@ -152,7 +153,7 @@ Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const { for (const auto& gradient : tensor_array_gradients_) { elems.push_back(gradient.second->value_); } - *pack = builder->Tuple(elems); + *pack = xla::Tuple(builder, elems); } return Status::OK(); } @@ -168,7 +169,7 @@ Status XlaResource::SetFromPack(const std::set& gradient_sources, } else { TF_RET_CHECK(kind_ == kTensorArray); int pos = 0; - auto v = builder->GetTupleElement(pack, pos++); + auto v = xla::GetTupleElement(pack, pos++); if (!initialized()) { initial_value_ = v; } @@ -178,7 +179,7 @@ Status XlaResource::SetFromPack(const std::set& gradient_sources, XlaResource* gradient; TF_RETURN_IF_ERROR( GetOrCreateTensorArrayGradient(source, builder, &gradient)); - auto v = builder->GetTupleElement(pack, pos++); + auto v = xla::GetTupleElement(pack, pos++); if (!gradient->initialized()) { gradient->initial_value_ = v; } diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 9ce36d1aa7622334b2acfbe9aa85d7419c4772ed..4de18a77887496d30e3b1407ecd9042e619653af 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -75,7 +75,7 @@ class XlaResource { const xla::XlaOp& initial_value() const { return initial_value_; } // A variable is initialized if it has a value. - bool initialized() const { return value_.builder() != nullptr; } + bool initialized() const { return value_.valid(); } // Sets the type and shape of the resource. The type and shape of a resource // must not change once the variable has been initialized. diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 4525197146b7f29f405650bdb08e5946cbce8114..03e542855ba0e3ae81e0b754eb319cadbd5079ba 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -142,30 +142,15 @@ cc_library( cc_library( name = "statusor", - srcs = ["statusor.cc"], hdrs = [ "statusor.h", - "statusor_internals.h", ], visibility = ["//visibility:public"], deps = [ ":status", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - ], -) - -tf_cc_test( - name = "statusor_test", - size = "small", - srcs = ["statusor_test.cc"], - deps = [ - ":statusor", - ":test", - ":types", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", + "//tensorflow/stream_executor", ], ) @@ -175,6 +160,7 @@ cc_library( hdrs = [ "iterator_util.h", "map_util.h", + "overflow_util.h", "ptr_util.h", "util.h", ], @@ -250,7 +236,7 @@ cc_library( ":types", ":util", ":xla_data_proto", - "//tensorflow/core:framework_internal", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index d49d959a6c8112d3701857a70cecb24701c7b6d9..273fa1737106c61c337edf7a5ffaca0063496d7b 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -13,6 +13,12 @@ filegroup( ]), ) +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") + +# Generate test_suites for all backends, named "${backend}_tests". +generate_backend_suites() + cc_library( name = "arithmetic", srcs = ["arithmetic.cc"], @@ -28,6 +34,32 @@ cc_library( ], ) +cc_library( + name = "numeric", + srcs = ["numeric.cc"], + hdrs = ["numeric.h"], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + ], +) + +xla_test( + name = "numeric_test", + srcs = ["numeric_test.cc"], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":numeric", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "testing", srcs = ["testing.cc"], diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 639f85737f0173f47d494f366b220ab60e09629e..8c314fa61bbd67774c91e7e34e93730dbe77eb8d 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -42,8 +42,8 @@ XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, } const Shape scalar = ShapeUtil::MakeShape(type, {}); - auto lhs = b->Parameter(0, scalar, "lhs"); - auto rhs = b->Parameter(1, scalar, "rhs"); + auto lhs = Parameter(b.get(), 0, scalar, "lhs"); + auto rhs = Parameter(b.get(), 1, scalar, "rhs"); generator(b.get(), lhs, rhs); return b->BuildAndNoteError(); } @@ -55,7 +55,7 @@ XlaComputation CreateScalarAddComputation(PrimitiveType type, return CreateScalarComputation( "add", type, builder, [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return b->Add(lhs, rhs); + return Add(lhs, rhs); }); } @@ -64,17 +64,15 @@ XlaComputation CreateScalarMultiplyComputation(PrimitiveType type, return CreateScalarComputation( "mul", type, builder, [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return b->Mul(lhs, rhs); + return Mul(lhs, rhs); }); } XlaComputation CreateScalarGeComputation(PrimitiveType type, XlaBuilder* builder) { - return CreateScalarComputation( - "ge", type, builder, - [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return b->Ge(lhs, rhs); - }); + return CreateScalarComputation("ge", type, builder, + [](XlaBuilder* b, const XlaOp& lhs, + const XlaOp& rhs) { return Ge(lhs, rhs); }); } XlaComputation CreateScalarMaxComputation(PrimitiveType type, @@ -82,7 +80,7 @@ XlaComputation CreateScalarMaxComputation(PrimitiveType type, return CreateScalarComputation( "max", type, builder, [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return b->Max(lhs, rhs); + return Max(lhs, rhs); }); } @@ -91,7 +89,7 @@ XlaComputation CreateScalarMinComputation(PrimitiveType type, return CreateScalarComputation( "min", type, builder, [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return b->Min(lhs, rhs); + return Min(lhs, rhs); }); } @@ -99,32 +97,32 @@ XlaComputation CreateScalarAndComputation(XlaBuilder* builder) { return CreateScalarComputation( "and", PRED, builder, [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return b->And(lhs, rhs); + return And(lhs, rhs); }); } XlaComputation CreateScalarOrComputation(XlaBuilder* builder) { - return CreateScalarComputation( - "or", PRED, builder, - [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return b->Or(lhs, rhs); - }); + return CreateScalarComputation("or", PRED, builder, + [](XlaBuilder* b, const XlaOp& lhs, + const XlaOp& rhs) { return Or(lhs, rhs); }); } -StatusOr Any(const XlaOp& predicates, XlaBuilder* builder) { - auto f = builder->ConstantR0(false); - XlaComputation logical_or = CreateScalarOrComputation(builder); - TF_ASSIGN_OR_RETURN(const Shape& predicates_shape, - builder->GetShape(predicates)); - std::vector all_dimensions(ShapeUtil::Rank(predicates_shape)); - std::iota(all_dimensions.begin(), all_dimensions.end(), 0); - return builder->Reduce(predicates, f, logical_or, all_dimensions); +XlaOp Any(XlaOp predicates) { + XlaBuilder* builder = predicates.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + auto f = ConstantR0(builder, false); + XlaComputation logical_or = CreateScalarOrComputation(builder); + TF_ASSIGN_OR_RETURN(const Shape& predicates_shape, + builder->GetShape(predicates)); + std::vector all_dimensions(ShapeUtil::Rank(predicates_shape)); + std::iota(all_dimensions.begin(), all_dimensions.end(), 0); + return Reduce(predicates, f, logical_or, all_dimensions); + }); } namespace { -xla::XlaOp FloatLiteral(xla::XlaBuilder* b, PrimitiveType data_type, - float value) { - return b->ConvertElementType(b->ConstantR0(value), data_type); +XlaOp FloatLiteral(XlaBuilder* b, PrimitiveType data_type, float value) { + return ConvertElementType(ConstantR0(b, value), data_type); } // Polynomials for computing erf/erfc. Originally from cephes. @@ -165,44 +163,91 @@ std::array kErfUCoefficient = { // Evaluate the polynomial given coefficients and `x`. // N.B. Coefficients should be supplied in decreasing order. -xla::XlaOp EvaluatePolynomial(xla::XlaBuilder* b, const xla::XlaOp& x, - tensorflow::gtl::ArraySlice coefficients, - PrimitiveType data_type) { - xla::XlaOp poly = FloatLiteral(b, data_type, 0.0); +XlaOp EvaluatePolynomial(XlaOp x, + tensorflow::gtl::ArraySlice coefficients, + PrimitiveType data_type) { + XlaBuilder* b = x.builder(); + XlaOp poly = FloatLiteral(b, data_type, 0.0); for (float c : coefficients) { - poly = b->Add(b->Mul(poly, x), FloatLiteral(b, data_type, c)); + poly = Add(Mul(poly, x), FloatLiteral(b, data_type, c)); } return poly; } // Compute an approximation of the error function complement (1 - erf(x)). -xla::XlaOp ComputeErfc(xla::XlaBuilder* b, const xla::XlaOp& x, - PrimitiveType data_type) { - xla::XlaOp zero = FloatLiteral(b, data_type, 0.0); - xla::XlaOp two = FloatLiteral(b, data_type, 2.0); - xla::XlaOp eight = FloatLiteral(b, data_type, 8.0); +XlaOp Erfc(XlaOp x, PrimitiveType data_type) { + XlaBuilder* b = x.builder(); + XlaOp zero = FloatLiteral(b, data_type, 0.0); + XlaOp two = FloatLiteral(b, data_type, 2.0); + XlaOp eight = FloatLiteral(b, data_type, 8.0); - xla::XlaOp abs_x = b->Abs(x); - xla::XlaOp z = b->Exp(b->Mul(b->Neg(x), x)); + XlaOp abs_x = Abs(x); + XlaOp z = Exp(Mul(Neg(x), x)); - xla::XlaOp pp = EvaluatePolynomial(b, abs_x, kErfcPCoefficient, data_type); - xla::XlaOp pq = EvaluatePolynomial(b, abs_x, kErfcQCoefficient, data_type); - xla::XlaOp pr = EvaluatePolynomial(b, abs_x, kErfcRCoefficient, data_type); - xla::XlaOp ps = EvaluatePolynomial(b, abs_x, kErfcSCoefficient, data_type); + XlaOp pp = EvaluatePolynomial(abs_x, kErfcPCoefficient, data_type); + XlaOp pq = EvaluatePolynomial(abs_x, kErfcQCoefficient, data_type); + XlaOp pr = EvaluatePolynomial(abs_x, kErfcRCoefficient, data_type); + XlaOp ps = EvaluatePolynomial(abs_x, kErfcSCoefficient, data_type); - xla::XlaOp y = b->Select(b->Lt(abs_x, eight), b->Div(b->Mul(z, pp), pq), - b->Div(b->Mul(z, pr), ps)); + XlaOp y = Select(Lt(abs_x, eight), Div(Mul(z, pp), pq), Div(Mul(z, pr), ps)); - return b->Select(b->Lt(x, zero), b->Sub(two, y), y); + return Select(Lt(x, zero), Sub(two, y), y); } // Compute a polynomial approximation of the error function. -xla::XlaOp ComputeErf(xla::XlaBuilder* b, const xla::XlaOp& x, - PrimitiveType data_type) { - xla::XlaOp z = b->Mul(x, x); - xla::XlaOp pt = EvaluatePolynomial(b, z, kErfTCoefficient, data_type); - xla::XlaOp pu = EvaluatePolynomial(b, z, kErfUCoefficient, data_type); - return b->Div(b->Mul(x, pt), pu); +XlaOp Erf(XlaOp x, PrimitiveType data_type) { + XlaOp z = Mul(x, x); + XlaOp pt = EvaluatePolynomial(z, kErfTCoefficient, data_type); + XlaOp pu = EvaluatePolynomial(z, kErfUCoefficient, data_type); + return Div(Mul(x, pt), pu); +} + +// Approximation for the inverse error function from +// Giles, M., "Approximating the erfinv function". +// The approximation has the form: +// w = -log((1 - x) * (1 + x)) +// if ( w < 5 ) { +// w = w - 2.5 +// p = sum_{i=1}^n lq[i]*w^i +// } else { +// w = sqrt(w) - 3 +// p = sum_{i=1}^n gq[i]*w^i +// } +// return p*x +XlaOp ErfInv(XlaOp x) { + XlaBuilder* b = x.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x)); + constexpr int kDegree = 9; + constexpr std::array w_less_than_5_constants = { + 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, + -4.39150654e-06f, 0.00021858087f, -0.00125372503f, + -0.00417768164f, 0.246640727f, 1.50140941f}; + constexpr std::array w_greater_than_5_constants = { + -0.000200214257f, 0.000100950558f, 0.00134934322f, + -0.00367342844f, 0.00573950773f, -0.0076224613f, + 0.00943887047f, 1.00167406f, 2.83297682f}; + + auto one = ConstantR0(b, 1.0); + auto w = Neg(Log(Mul(Sub(one, x), Add(one, x)))); + + auto lt = Lt(w, ConstantR0(b, 5.0)); + auto coefficient = [&](int i) { + return Select( + lt, + Broadcast(ConstantR0(b, w_less_than_5_constants[i]), + AsInt64Slice(shape.dimensions())), + Broadcast(ConstantR0(b, w_greater_than_5_constants[i]), + AsInt64Slice(shape.dimensions()))); + }; + w = Select(lt, Sub(w, ConstantR0(b, 2.5f)), + Sub(SqrtF32(w), ConstantR0(b, 3.0f))); + auto p = coefficient(0); + for (int i = 1; i < kDegree; ++i) { + p = Add(coefficient(i), Mul(p, w)); + } + return Mul(p, x); + }); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index f11cc003177c7eb68c32f9e618704a1ac7e63a73..d0e04bbb5eb5365ab3f45dcaf4d8c389d2e77fa1 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -53,21 +53,22 @@ XlaComputation CreateScalarOrComputation(XlaBuilder* builder); // Returns whether any predicate in "predicates" is set. // // Note: if predicates is zero-sized, Any() vacuously returns false. -StatusOr Any(const XlaOp& predicates, XlaBuilder* builder); +XlaOp Any(XlaOp predicates); // Evaluate the polynomial given coefficients and `x`. // N.B. Coefficients should be supplied in decreasing order. -xla::XlaOp EvaluatePolynomial(xla::XlaBuilder* b, const xla::XlaOp& x, - tensorflow::gtl::ArraySlice coefficients, - PrimitiveType data_type); +XlaOp EvaluatePolynomial(XlaOp x, + tensorflow::gtl::ArraySlice coefficients, + PrimitiveType data_type); // Compute an approximation of the error function complement (1 - erf(x)). -xla::XlaOp ComputeErfc(xla::XlaBuilder* b, const xla::XlaOp& x, - PrimitiveType data_type); +XlaOp Erfc(XlaOp x, PrimitiveType data_type); // Compute an approximation of the error function. -xla::XlaOp ComputeErf(xla::XlaBuilder* b, const xla::XlaOp& x, - PrimitiveType data_type); +XlaOp Erf(XlaOp x, PrimitiveType data_type); + +// Compute an approximation of the inverse of the error function. +XlaOp ErfInv(XlaOp x); } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc new file mode 100644 index 0000000000000000000000000000000000000000..cbe9e7fdd1330164f1f9c4520c2bb81e38f4ceb9 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/numeric.cc @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/numeric.h" + +#include +#include + +namespace xla { + +namespace { + +template +XlaOp MakeIota(XlaBuilder* builder, int64 size) { + std::vector values(size); + for (int64 i = 0; i < size; ++i) { + values[i] = static_cast(i); + } + return xla::ConstantR1(builder, values); +} + +} // namespace + +XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) { + switch (type) { + case S8: + return MakeIota(builder, size); + case S16: + return MakeIota(builder, size); + case S32: + return MakeIota(builder, size); + case S64: + return MakeIota(builder, size); + case U8: + return MakeIota(builder, size); + case U16: + return MakeIota(builder, size); + case U32: + return MakeIota(builder, size); + case U64: + return MakeIota(builder, size); + case BF16: + return MakeIota(builder, size); + case F16: + return MakeIota(builder, size); + case F32: + return MakeIota(builder, size); + case F64: + return MakeIota(builder, size); + case C64: + return MakeIota(builder, size); + default: + return builder->ReportError( + InvalidArgument("Unimplemented type for Iota: %s.", + PrimitiveType_Name(type).c_str())); + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/numeric.h b/tensorflow/compiler/xla/client/lib/numeric.h new file mode 100644 index 0000000000000000000000000000000000000000..2a409ae31147a4a88367422ce31c9fbcb22fdbca --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/numeric.h @@ -0,0 +1,30 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ + +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Returns a rank 1 tensor of `type` containing values [0, 1, 2, ...]. +XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/numeric_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bc8a73e9d793ef8f65c321759e03b0de75edd500 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/numeric_test.cc @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +using NumericTest = ClientLibraryTestBase; + +XLA_TEST_F(NumericTest, Iota) { + XlaBuilder builder(TestName()); + Iota(&builder, S32, 10); + + ComputeAndCompareR1(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 3380af9f303b1dc2cec09aa37410ec40cdeaa526..731ad13b8d0e5d65acc316e72be9fe7d35e826a4 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -48,15 +48,15 @@ int64 DataSizeOfShape(const Shape& shape) { // Creates a XlaOp for an op what generates fake data with the given shape. XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) { if (ShapeUtil::IsArray(shape)) { - return builder->Broadcast( - builder->ConstantLiteral(Literal::One(shape.element_type())), + return Broadcast( + ConstantLiteral(builder, Literal::One(shape.element_type())), AsInt64Slice(shape.dimensions())); } std::vector parts; for (const Shape& s : shape.tuple_shapes()) { parts.push_back(BuildFakeDataOpOnDevice(s, builder)); } - return builder->Tuple(parts); + return Tuple(builder, parts); } std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index ae0308020d014e038d2f0fd7de6c5f372d6cbed1..5f9710914bd0ceff55f5b0a2db05e553ce8bd637 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -51,24 +51,17 @@ LocalExecutable::LocalExecutable(std::unique_ptr executable, Status LocalExecutable::ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, const ExecutableRunOptions& run_options, const Backend& backend) { - const ComputationLayout& host_computation_layout = - executable_->module_config().host_entry_computation_layout(); - const ComputationLayout& device_computation_layout = - executable_->module_config().device_entry_computation_layout(); + const ComputationLayout& computation_layout = + executable_->module_config().entry_computation_layout(); // Check argument number, shapes, and layouts. - if (arguments.size() != host_computation_layout.parameter_count()) { + if (arguments.size() != computation_layout.parameter_count()) { return InvalidArgument( "invalid number of arguments for computation: expected %d, got %zu", - host_computation_layout.parameter_count(), arguments.size()); - } - if (arguments.size() != device_computation_layout.parameter_count()) { - return InvalidArgument( - "invalid number of arguments for computation: expected %d, got %zu", - device_computation_layout.parameter_count(), arguments.size()); + computation_layout.parameter_count(), arguments.size()); } for (int i = 0; i < arguments.size(); ++i) { - if (!host_computation_layout.parameter_layout(i).MatchesLayoutInShape( + if (!computation_layout.parameter_layout(i).MatchesLayoutInShape( arguments[i]->on_host_shape())) { return InvalidParameterArgument( executable_.get(), i, @@ -76,24 +69,10 @@ Status LocalExecutable::ValidateExecutionOptions( "parameter " "%d: want %s, got %s", i, - ShapeUtil::HumanString( - host_computation_layout.parameter_layout(i).shape()) + ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape()) .c_str(), ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str()); } - if (!device_computation_layout.parameter_layout(i).MatchesLayoutInShape( - arguments[i]->on_device_shape())) { - return InvalidParameterArgument( - executable_.get(), i, - "Argument does not match device shape or layout of computation " - "parameter " - "%d: want %s, got %s", - i, - ShapeUtil::HumanString( - device_computation_layout.parameter_layout(i).shape()) - .c_str(), - ShapeUtil::HumanString(arguments[i]->on_device_shape()).c_str()); - } } if (run_options.stream() != nullptr) { @@ -230,10 +209,9 @@ Status LocalExecutable::RecordResult(const ShapedBuffer* result, StatusOr> LocalExecutable::LiteralFromShapedBuffer( const ShapedBuffer& shaped_buffer) { - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - backend_->stream_executor(shaped_buffer.device_ordinal())); - return backend_->transfer_manager()->TransferLiteralFromDevice(executor, + TF_ASSIGN_OR_RETURN(auto stream, + backend_->BorrowStream(shaped_buffer.device_ordinal())); + return backend_->transfer_manager()->TransferLiteralFromDevice(stream.get(), shaped_buffer); } @@ -288,19 +266,18 @@ StatusOr LocalClient::LiteralToShapedBuffer( TF_ASSIGN_OR_RETURN(auto scoped_buffer, backend().transfer_manager()->AllocateScopedShapedBuffer( literal.shape(), allocator, device_ordinal)); - TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, - backend().stream_executor(device_ordinal)); + TF_ASSIGN_OR_RETURN(auto stream, + mutable_backend()->BorrowStream(device_ordinal)); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - executor, literal, scoped_buffer)); + stream.get(), literal, scoped_buffer)); return std::move(scoped_buffer); } StatusOr> LocalClient::ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer) { - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - backend().stream_executor(shaped_buffer.device_ordinal())); - return backend().transfer_manager()->TransferLiteralFromDevice(executor, + TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream( + shaped_buffer.device_ordinal())); + return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(), shaped_buffer); } diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD index 507a2dc5f088e159156f0ef3d663ba2819f6a2d4..ee00a9eada8dd906c26e07a4affccdaf544f1693 100644 --- a/tensorflow/compiler/xla/client/xla_client/BUILD +++ b/tensorflow/compiler/xla/client/xla_client/BUILD @@ -1,7 +1,5 @@ # Description: # The new XLA client libraries. -# -# This is NOT YET ready to use. licenses(["notice"]) # Apache 2.0 @@ -41,6 +39,7 @@ cc_library( name = "xla_builder", srcs = ["xla_builder.cc"], hdrs = ["xla_builder.h"], + visibility = ["//visibility:public"], deps = [ ":xla_computation", "//tensorflow/compiler/xla:execution_options_util", @@ -52,6 +51,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client:sharding_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:shape_inference", diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index d7ebcf8bebc1f656b4965c833e0d42ccceb1b99f..4f683a4115c15c029697978f586bf8a45083f597 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -59,6 +60,36 @@ bool CanBeRoot(HloOpcode opcode) { } // namespace +XlaOp operator-(const XlaOp& x) { return Neg(x); } +XlaOp operator+(const XlaOp& x, const XlaOp& y) { return Add(x, y); } +XlaOp operator-(const XlaOp& x, const XlaOp& y) { return Sub(x, y); } +XlaOp operator*(const XlaOp& x, const XlaOp& y) { return Mul(x, y); } +XlaOp operator/(const XlaOp& x, const XlaOp& y) { return Div(x, y); } +XlaOp operator%(const XlaOp& x, const XlaOp& y) { return Rem(x, y); } + +XlaOp operator~(const XlaOp& x) { return Not(x); } +XlaOp operator&(const XlaOp& x, const XlaOp& y) { return And(x, y); } +XlaOp operator|(const XlaOp& x, const XlaOp& y) { return Or(x, y); } +XlaOp operator^(const XlaOp& x, const XlaOp& y) { return Xor(x, y); } +XlaOp operator<<(const XlaOp& x, const XlaOp& y) { return ShiftLeft(x, y); } + +XlaOp operator>>(const XlaOp& x, const XlaOp& y) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + if (!ShapeUtil::ElementIsIntegral(shape)) { + return InvalidArgument( + "Argument to >> operator does not have an integral type (%s).", + ShapeUtil::HumanString(shape).c_str()); + } + if (ShapeUtil::ElementIsSigned(shape)) { + return ShiftRightArithmetic(x, y); + } else { + return ShiftRightLogical(x, y); + } + }); +} + StatusOr XlaBuilder::GetShape(const XlaOp& op) const { TF_RETURN_IF_ERROR(first_error_); @@ -81,7 +112,7 @@ XlaBuilder::XlaBuilder(const string& computation_name) XlaBuilder::~XlaBuilder() {} -void XlaBuilder::NoteError(const Status& error) { +XlaOp XlaBuilder::ReportError(const Status& error) { CHECK(!error.ok()); if (die_immediately_on_error_) { LOG(FATAL) << "error building computation: " << error; @@ -91,19 +122,22 @@ void XlaBuilder::NoteError(const Status& error) { first_error_ = error; first_error_backtrace_.CreateCurrent(/*skip_count=*/1); } + return XlaOp(this); } -XlaOp XlaBuilder::NoteErrorOrReturn( - const std::function()>& op_creator) { +XlaOp XlaBuilder::ReportErrorOrReturn(const StatusOr& op) { if (!first_error_.ok()) { - return {}; + return XlaOp(this); } - auto op = op_creator(); if (!op.ok()) { - NoteError(op.status()); - return {}; + return ReportError(op.status()); } - return op.ConsumeValueOrDie(); + return op.ValueOrDie(); +} + +XlaOp XlaBuilder::ReportErrorOrReturn( + const std::function()>& op_creator) { + return ReportErrorOrReturn(op_creator()); } StatusOr XlaBuilder::GetProgramShape(int64* root_id) const { @@ -207,7 +241,7 @@ XlaComputation XlaBuilder::BuildAndNoteError() { DCHECK(parent_builder_ != nullptr); auto build_status = Build(); if (!build_status.ok()) { - parent_builder_->NoteError( + parent_builder_->ReportError( AddStatus(build_status.status(), tensorflow::strings::StrCat("error from: ", name_))); return {}; @@ -315,7 +349,7 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, } XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), @@ -327,7 +361,7 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { XlaOp XlaBuilder::BinaryOp( HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -383,7 +417,7 @@ XlaOp XlaBuilder::BinaryOp( XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, const XlaOp& ehs) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -430,7 +464,7 @@ XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, } XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = literal.shape(); *instr.mutable_literal() = literal.ToProto(); @@ -440,7 +474,7 @@ XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { XlaOp XlaBuilder::Call(const XlaComputation& computation, tensorflow::gtl::ArraySlice operands) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); @@ -461,7 +495,7 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation, XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, const string& name) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; if (!parameter_numbers_.insert(parameter_number).second) { return InvalidArgument("parameter %lld already registered", @@ -476,7 +510,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, XlaOp XlaBuilder::Broadcast( const XlaOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( const Shape& shape, @@ -510,7 +544,7 @@ XlaOp XlaBuilder::Slice(const XlaOp& operand, tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -530,7 +564,7 @@ XlaOp XlaBuilder::Slice(const XlaOp& operand, XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); std::vector starts(ShapeUtil::Rank(shape), 0); std::vector limits(shape.dimensions().begin(), @@ -545,7 +579,7 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, tensorflow::gtl::ArraySlice slice_sizes) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -566,7 +600,7 @@ XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -584,7 +618,7 @@ XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice operands, int64 dimension) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; @@ -603,7 +637,7 @@ XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice operands, XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value, const PaddingConfig& padding_config) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -624,7 +658,7 @@ XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value, XlaOp XlaBuilder::Reshape(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(const Shape& shape, ShapeInference::InferReshapeShape( @@ -638,7 +672,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, XlaOp XlaBuilder::Reshape(const XlaOp& operand, tensorflow::gtl::ArraySlice new_sizes) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand)); std::vector dimensions(shape.dimensions_size()); std::iota(dimensions.begin(), dimensions.end(), 0); @@ -648,7 +682,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, XlaOp XlaBuilder::Collapse(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { if (dimensions.size() <= 1) { // Not collapsing anything, trivially we can return the operand versus // enqueueing a trivial reshape. @@ -690,7 +724,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, } void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { - NoteErrorOrReturn([&]() -> StatusOr { + ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = ShapeUtil::MakeNil(); *instr.mutable_literal() = Literal::CreateR1U8(tag)->ToProto(); @@ -704,7 +738,7 @@ XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true, } XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice elements) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements)); @@ -718,7 +752,7 @@ XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice elements) { } XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& tuple_shape, GetShape(tuple_data)); if (!ShapeUtil::IsTuple(tuple_shape)) { @@ -767,7 +801,7 @@ XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs, } XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); DotDimensionNumbers dimension_numbers; @@ -780,7 +814,7 @@ XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) { XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -859,7 +893,7 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -905,7 +939,7 @@ XlaOp XlaBuilder::ConvGeneralDilated( tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -992,7 +1026,7 @@ StatusOr XlaBuilder::MakeWindow( XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, const tensorflow::gtl::ArraySlice fft_length) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1009,23 +1043,69 @@ XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, } XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; if (!LayoutUtil::HasLayout(shape)) { return InvalidArgument("Given shape to Infeed must have a layout"); } - *instr.mutable_shape() = shape; + const Shape infeed_instruction_shape = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); + *instr.mutable_shape() = infeed_instruction_shape; instr.set_infeed_config(config); - return AddInstruction(std::move(instr), HloOpcode::kInfeed); + + if (ShapeUtil::IsArray(shape) && sharding() && + sharding()->type() == OpSharding::Type::OpSharding_Type_OTHER) { + // TODO(b/110793772): Support tiled array-shaped infeeds. + return InvalidArgument( + "Tiled sharding is not yet supported for array-shaped infeeds"); + } + + if (sharding() && + sharding()->type() == OpSharding::Type::OpSharding_Type_REPLICATED) { + return InvalidArgument( + "Replicated sharding is not yet supported for infeeds"); + } + + // The sharding is set by the client according to the data tuple shape. + // However, the shape of the infeed instruction is a tuple containing the + // data and a token. For tuple sharding type, the sharding must be changed + // to accommodate the token. + XlaOp infeed; + if (sharding() && + sharding()->type() == OpSharding::Type::OpSharding_Type_TUPLE) { + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + OpSharding infeed_instruction_sharding = *sharding(); + // Arbitrarily assign the token to device 0. + *infeed_instruction_sharding.add_tuple_shardings() = + sharding_builder::AssignDevice(0); + XlaScopedShardingAssignment scoped_sharding(this, + infeed_instruction_sharding); + TF_ASSIGN_OR_RETURN(infeed, + AddInstruction(std::move(instr), HloOpcode::kInfeed)); + } else { + TF_ASSIGN_OR_RETURN(infeed, + AddInstruction(std::move(instr), HloOpcode::kInfeed)); + } + + // The infeed instruction produces a tuple of the infed data and a token + // type. Return XLA op containing the data. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto infeed_data; + *infeed_data.mutable_shape() = shape; + infeed_data.set_tuple_index(0); + return AddInstruction(std::move(infeed_data), HloOpcode::kGetTupleElement, + {infeed}); }); } void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, const string& outfeed_config) { - NoteErrorOrReturn([&]() -> StatusOr { + ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeNil(); + *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); // Check and set outfeed shape. if (!LayoutUtil::HasLayout(shape_with_layout)) { @@ -1042,14 +1122,33 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, instr.set_outfeed_config(outfeed_config); - return AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand}); + TF_RETURN_IF_ERROR( + AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand}) + .status()); + + // The outfeed instruction produces a token. However, existing users expect + // a nil shape (empty tuple). This should only be relevant if the outfeed is + // the root of a computation. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto tuple_instr; + *tuple_instr.mutable_shape() = ShapeUtil::MakeNil(); + + // The dummy tuple should have no sharding. + { + XlaScopedShardingAssignment scoped_sharding(this, OpSharding()); + TF_ASSIGN_OR_RETURN( + XlaOp empty_tuple, + AddInstruction(std::move(tuple_instr), HloOpcode::kTuple, {})); + return empty_tuple; + } }); } XlaOp XlaBuilder::CustomCall(const string& call_target_name, tensorflow::gtl::ArraySlice operands, const Shape& shape) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; if (tensorflow::str_util::StartsWith(call_target_name, "$")) { return InvalidArgument( @@ -1066,7 +1165,7 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name, XlaOp XlaBuilder::HostCompute(tensorflow::gtl::ArraySlice operands, const string& channel_name, int64 cost_estimate_ns, const Shape& shape) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = shape; instr.set_channel_name(channel_name); @@ -1120,11 +1219,9 @@ XlaOp XlaBuilder::Or(const XlaOp& lhs, const XlaOp& rhs, return BinaryOp(HloOpcode::kOr, lhs, rhs, broadcast_dimensions); } -// TODO(b/65209188): Create a dedicated lowering for Xor. XlaOp XlaBuilder::Xor(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - return Or(And(Not(lhs), rhs, broadcast_dimensions), - And(lhs, Not(rhs), broadcast_dimensions)); + return BinaryOp(HloOpcode::kXor, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Not(const XlaOp& operand) { @@ -1223,7 +1320,7 @@ XlaOp XlaBuilder::IsFinite(const XlaOp& operand) { XlaOp XlaBuilder::Transpose(const XlaOp& operand, tensorflow::gtl::ArraySlice permutation) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1238,7 +1335,7 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand, XlaOp XlaBuilder::Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1251,8 +1348,25 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand, }); } -XlaOp XlaBuilder::Sort(const XlaOp& operand) { - return UnaryOp(HloOpcode::kSort, operand); +XlaOp XlaBuilder::Sort(XlaOp keys, tensorflow::gtl::optional values) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + std::vector operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); + operand_shape_ptrs.push_back(&keys_shape); + Shape values_shape; + if (values.has_value()) { + TF_ASSIGN_OR_RETURN(values_shape, GetShape(*values)); + operand_shape_ptrs.push_back(&values_shape); + } + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferVariadicOpShape( + HloOpcode::kSort, operand_shape_ptrs)); + return values.has_value() + ? AddInstruction(std::move(instr), HloOpcode::kSort, + {keys, *values}) + : AddInstruction(std::move(instr), HloOpcode::kSort, {keys}); + }); } XlaOp XlaBuilder::SqrtF32(const XlaOp& operand) { @@ -1267,7 +1381,7 @@ XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs, XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1279,7 +1393,7 @@ XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand, XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1313,13 +1427,12 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice operands, const XlaComputation& computation, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice static_operands) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { if (!static_operands.empty()) { return Unimplemented("static_operands is not supported in Map"); } HloInstructionProto instr; - std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), @@ -1331,16 +1444,32 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice operands, ShapeInference::InferMapShape(operand_shape_ptrs, called_program_shape, dimensions)); + const Shape& output_shape = instr.shape(); + const int64 output_rank = ShapeUtil::Rank(output_shape); AddCalledComputation(computation, &instr); + std::vector new_operands(operands.begin(), operands.end()); + for (XlaOp& new_operand : new_operands) { + TF_ASSIGN_OR_RETURN(Shape shape, GetShape(new_operand)); + const int64 rank = ShapeUtil::Rank(shape); + if (rank != output_rank) { + TF_ASSIGN_OR_RETURN(new_operand, + InDimBroadcast(output_shape, new_operand, {})); + TF_ASSIGN_OR_RETURN(shape, GetShape(new_operand)); + } + if (!ShapeUtil::SameDimensions(output_shape, shape)) { + TF_ASSIGN_OR_RETURN(new_operand, + AddBroadcastSequence(output_shape, new_operand)); + } + } - return AddInstruction(std::move(instr), HloOpcode::kMap, operands); + return AddInstruction(std::move(instr), HloOpcode::kMap, new_operands); }); } XlaOp XlaBuilder::RngOp(RandomDistribution distribution, tensorflow::gtl::ArraySlice parameters, const Shape& shape) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; // Check the number of parameters per RNG distribution. @@ -1378,7 +1507,7 @@ XlaOp XlaBuilder::RngUniform(const XlaOp& a, const XlaOp& b, XlaOp XlaBuilder::While(const XlaComputation& condition, const XlaComputation& body, const XlaOp& init) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; // Infer shape. @@ -1400,7 +1529,7 @@ XlaOp XlaBuilder::While(const XlaComputation& condition, XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices, const GatherDimensionNumbers& dimension_numbers, tensorflow::gtl::ArraySlice window_bounds) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); @@ -1425,7 +1554,7 @@ XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, const XlaComputation& true_computation, const XlaOp& false_operand, const XlaComputation& false_computation) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& predicate_shape, GetShape(predicate)); @@ -1457,7 +1586,7 @@ XlaOp XlaBuilder::Reduce( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, tensorflow::gtl::ArraySlice dimensions_to_reduce) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1482,7 +1611,7 @@ XlaOp XlaBuilder::Reduce( XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); std::vector all_dimnos(ShapeUtil::Rank(operand_shape)); std::iota(all_dimnos.begin(), all_dimnos.end(), 0); @@ -1495,7 +1624,7 @@ XlaOp XlaBuilder::ReduceWindow( const XlaComputation& computation, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1518,7 +1647,7 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1542,7 +1671,7 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( XlaOp XlaBuilder::BatchNormTraining(const XlaOp& operand, const XlaOp& scale, const XlaOp& offset, float epsilon, int64 feature_index) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1565,7 +1694,7 @@ XlaOp XlaBuilder::BatchNormInference(const XlaOp& operand, const XlaOp& scale, const XlaOp& offset, const XlaOp& mean, const XlaOp& variance, float epsilon, int64 feature_index) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1590,7 +1719,7 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, const XlaOp& batch_mean, const XlaOp& batch_var, const XlaOp& grad_output, float epsilon, int64 feature_index) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1614,7 +1743,7 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, XlaOp XlaBuilder::CrossReplicaSum( const XlaOp& operand, tensorflow::gtl::ArraySlice replica_group_ids) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {}); auto b = CreateSubBuilder("sum"); @@ -1630,7 +1759,7 @@ XlaOp XlaBuilder::CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, tensorflow::gtl::ArraySlice replica_group_ids, const tensorflow::gtl::optional& channel_id) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { if (channel_id.has_value()) { return Unimplemented("channel_id is not supported in AllReduce"); } @@ -1657,7 +1786,7 @@ XlaOp XlaBuilder::SelectAndScatter( tensorflow::gtl::ArraySlice window_strides, Padding padding, const XlaOp& source, const XlaOp& init_value, const XlaComputation& scatter) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); return SelectAndScatterWithGeneralPadding( operand, select, window_dimensions, window_strides, @@ -1674,7 +1803,7 @@ XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( tensorflow::gtl::ArraySlice> padding, const XlaOp& source, const XlaOp& init_value, const XlaComputation& scatter) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1702,7 +1831,7 @@ XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits, const int mantissa_bits) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), @@ -1716,7 +1845,7 @@ XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits, } void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { - NoteErrorOrReturn([&]() -> StatusOr { + ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; // Send instruction produces a tuple of {aliased operand, U32 context}. @@ -1737,7 +1866,7 @@ void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { } XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; // Recv instruction produces a tuple of {receive buffer, U32 context}. @@ -1992,9 +2121,502 @@ StatusOr XlaBuilder::LookUpInstruction( return &instructions_[op.handle()]; } -XlaOp XlaBuilder::UnimplementedOp() { - NoteError(Unimplemented("Op not implemented")); - return {}; +// Enqueues a "retrieve parameter value" instruction for a parameter that was +// passed to the computation. +XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape, + const string& name) { + return builder->Parameter(parameter_number, shape, name); +} + +// Enqueues a constant with the value of the given literal onto the +// computation. +XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) { + return builder->ConstantLiteral(literal); +} + +XlaOp Broadcast(const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_sizes) { + return operand.builder()->Broadcast(operand, broadcast_sizes); +} + +XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config) { + return operand.builder()->Pad(operand, padding_value, padding_config); +} + +XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes) { + return operand.builder()->Reshape(operand, dimensions, new_sizes); +} + +XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice new_sizes) { + return operand.builder()->Reshape(operand, new_sizes); +} + +XlaOp Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions) { + return operand.builder()->Collapse(operand, dimensions); +} + +XlaOp Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides) { + return operand.builder()->Slice(operand, start_indices, limit_indices, + strides); +} + +XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, + int64 stride, int64 dimno) { + return operand.builder()->SliceInDim(operand, start_index, limit_index, + stride, dimno); +} + +XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes) { + return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); +} + +XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices) { + return operand.builder()->DynamicUpdateSlice(operand, update, start_indices); +} + +XlaOp ConcatInDim(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, + int64 dimension) { + return builder->ConcatInDim(operands, dimension); +} + +void Trace(const string& tag, const XlaOp& operand) { + return operand.builder()->Trace(tag, operand); +} + +XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false) { + return pred.builder()->Select(pred, on_true, on_false); +} + +XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice elements) { + return builder->Tuple(elements); +} + +XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index) { + return tuple_data.builder()->GetTupleElement(tuple_data, index); +} + +XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Eq(lhs, rhs, broadcast_dimensions); +} + +XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Ne(lhs, rhs, broadcast_dimensions); +} + +XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Ge(lhs, rhs, broadcast_dimensions); +} + +XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Gt(lhs, rhs, broadcast_dimensions); +} + +XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Lt(lhs, rhs, broadcast_dimensions); +} + +XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Le(lhs, rhs, broadcast_dimensions); +} + +XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs) { + return lhs.builder()->Dot(lhs, rhs); +} + +XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers) { + return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers); +} + +XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding) { + return lhs.builder()->Conv(lhs, rhs, window_strides, padding); +} + +XlaOp ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding) { + return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides, + padding); +} + +XlaOp ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers) { + return lhs.builder()->ConvWithGeneralDimensions(lhs, rhs, window_strides, + padding, dimension_numbers); +} + +XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ConvolutionDimensionNumbers& dimension_numbers) { + return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding, + dimension_numbers); +} + +XlaOp ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers) { + return lhs.builder()->ConvGeneralDilated(lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation, + dimension_numbers); +} + +XlaOp Fft(const XlaOp& operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length) { + return operand.builder()->Fft(operand, fft_type, fft_length); +} + +XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config) { + return builder->Infeed(shape, config); +} + +void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + const string& outfeed_config) { + return operand.builder()->Outfeed(operand, shape_with_layout, outfeed_config); +} + +XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, + tensorflow::gtl::ArraySlice operands) { + return builder->Call(computation, operands); +} + +XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, + tensorflow::gtl::ArraySlice operands, + const Shape& shape) { + return builder->CustomCall(call_target_name, operands, shape); +} + +XlaOp HostCompute(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, + const string& channel_name, int64 cost_estimate_ns, + const Shape& shape) { + return builder->HostCompute(operands, channel_name, cost_estimate_ns, shape); +} + +XlaOp Complex(const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return real.builder()->Complex(real, imag, broadcast_dimensions); +} + +XlaOp Conj(const XlaOp& operand) { return operand.builder()->Conj(operand); } + +XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Add(lhs, rhs, broadcast_dimensions); +} + +XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Sub(lhs, rhs, broadcast_dimensions); +} + +XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Mul(lhs, rhs, broadcast_dimensions); +} + +XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Div(lhs, rhs, broadcast_dimensions); +} + +XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Rem(lhs, rhs, broadcast_dimensions); +} + +XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Max(lhs, rhs, broadcast_dimensions); +} + +XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Min(lhs, rhs, broadcast_dimensions); +} + +XlaOp And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->And(lhs, rhs, broadcast_dimensions); +} + +XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Or(lhs, rhs, broadcast_dimensions); +} + +XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Xor(lhs, rhs, broadcast_dimensions); +} + +XlaOp Not(const XlaOp& operand) { return operand.builder()->Not(operand); } + +XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->ShiftLeft(lhs, rhs, broadcast_dimensions); +} + +XlaOp ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->ShiftRightArithmetic(lhs, rhs, broadcast_dimensions); +} + +XlaOp ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->ShiftRightLogical(lhs, rhs, broadcast_dimensions); +} + +XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce) { + return operand.builder()->Reduce(operand, init_value, computation, + dimensions_to_reduce); +} + +XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation) { + return operand.builder()->ReduceAll(operand, init_value, computation); +} + +XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding) { + return operand.builder()->ReduceWindow(operand, init_value, computation, + window_dimensions, window_strides, + padding); +} + +XlaOp ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding) { + return operand.builder()->ReduceWindowWithGeneralPadding( + operand, init_value, computation, window_dimensions, window_strides, + padding); +} + +XlaOp CrossReplicaSum(const XlaOp& operand, + tensorflow::gtl::ArraySlice replica_group_ids) { + return operand.builder()->CrossReplicaSum(operand, replica_group_ids); +} + +XlaOp CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids, + const tensorflow::gtl::optional& channel_id) { + return operand.builder()->CrossReplicaSum(operand, computation, + replica_group_ids, channel_id); +} + +XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter) { + return operand.builder()->SelectAndScatter(operand, select, window_dimensions, + window_strides, padding, source, + init_value, scatter); +} + +XlaOp SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter) { + return operand.builder()->SelectAndScatterWithGeneralPadding( + operand, select, window_dimensions, window_strides, padding, source, + init_value, scatter); +} + +XlaOp Abs(const XlaOp& operand) { return operand.builder()->Abs(operand); } + +XlaOp Atan2(const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return y.builder()->Atan2(y, x, broadcast_dimensions); +} + +XlaOp Exp(const XlaOp& operand) { return operand.builder()->Exp(operand); } + +XlaOp Expm1(const XlaOp& operand) { return operand.builder()->Expm1(operand); } + +XlaOp Floor(const XlaOp& operand) { return operand.builder()->Floor(operand); } + +XlaOp Ceil(const XlaOp& operand) { return operand.builder()->Ceil(operand); } + +XlaOp Round(const XlaOp& operand) { return operand.builder()->Round(operand); } + +XlaOp Log(const XlaOp& operand) { return operand.builder()->Log(operand); } + +XlaOp Log1p(const XlaOp& operand) { return operand.builder()->Log1p(operand); } + +XlaOp Sign(const XlaOp& operand) { return operand.builder()->Sign(operand); } + +XlaOp Clz(const XlaOp& operand) { return operand.builder()->Clz(operand); } + +XlaOp Cos(const XlaOp& operand) { return operand.builder()->Cos(operand); } + +XlaOp Sin(const XlaOp& operand) { return operand.builder()->Sin(operand); } + +XlaOp Tanh(const XlaOp& operand) { return operand.builder()->Tanh(operand); } + +XlaOp Real(const XlaOp& operand) { return operand.builder()->Real(operand); } + +XlaOp Imag(const XlaOp& operand) { return operand.builder()->Imag(operand); } + +XlaOp SqrtF32(const XlaOp& operand) { + return operand.builder()->SqrtF32(operand); +} + +XlaOp SquareF32(const XlaOp& operand) { + return operand.builder()->SquareF32(operand); +} + +XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Pow(lhs, rhs, broadcast_dimensions); +} + +XlaOp IsFinite(const XlaOp& operand) { + return operand.builder()->IsFinite(operand); +} + +XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type) { + return operand.builder()->ConvertElementType(operand, new_element_type); +} + +XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) { + return operand.builder()->BitcastConvertType(operand, new_element_type); +} + +XlaOp ReciprocalF32(const XlaOp& operand) { + return operand.builder()->ReciprocalF32(operand); +} + +XlaOp Neg(const XlaOp& operand) { return operand.builder()->Neg(operand); } + +XlaOp Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice permutation) { + return operand.builder()->Transpose(operand, permutation); +} + +XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions) { + return operand.builder()->Rev(operand, dimensions); +} + +XlaOp Sort(XlaOp keys, tensorflow::gtl::optional values) { + return keys.builder()->Sort(keys, std::move(values)); +} + +XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) { + return min.builder()->Clamp(min, operand, max); +} + +XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands) { + return builder->Map(operands, computation, dimensions, static_operands); +} + +XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape) { + return mu.builder()->RngNormal(mu, sigma, shape); +} + +XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape) { + return a.builder()->RngUniform(a, b, shape); +} + +XlaOp While(const XlaComputation& condition, const XlaComputation& body, + const XlaOp& init) { + return init.builder()->While(condition, body, init); +} + +XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation) { + return predicate.builder()->Conditional(predicate, true_operand, + true_computation, false_operand, + false_computation); +} + +XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits) { + return operand.builder()->ReducePrecision(operand, exponent_bits, + mantissa_bits); +} + +XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds) { + return input.builder()->Gather(input, gather_indices, dimension_numbers, + window_bounds); +} + +void Send(const XlaOp& operand, const ChannelHandle& handle) { + return operand.builder()->Send(operand, handle); +} + +XlaOp Recv(XlaBuilder* builder, const Shape& shape, + const ChannelHandle& handle) { + return builder->Recv(shape, handle); +} + +XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index) { + return operand.builder()->BatchNormTraining(operand, scale, offset, epsilon, + feature_index); +} + +XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index) { + return operand.builder()->BatchNormInference( + operand, scale, offset, mean, variance, epsilon, feature_index); +} + +XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index) { + return operand.builder()->BatchNormGrad(operand, scale, batch_mean, batch_var, + grad_output, epsilon, feature_index); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index 0329e42ed1aef8edd1537e888ddcd78f08584407..ac6ad8734948c4ae898e0dbf24d422c0628c294f 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include "tensorflow/compiler/xla/client/padding.h" @@ -46,17 +47,23 @@ class XlaBuilder; // instruction as an operand. class XlaOp { public: - XlaOp() : handle_(0), builder_(nullptr) {} - ~XlaOp() {} + XlaOp() : handle_(-1), builder_(nullptr) { + static_assert(std::is_trivially_destructible::value, + "XlaOp should be trivially destructible"); + } + ~XlaOp() = default; - const XlaBuilder* builder() const { return builder_; } + XlaBuilder* builder() const { return builder_; } - bool operator==(const XlaOp& rhs) const { - return handle_ == rhs.handle_ && builder_ == rhs.builder_; - } + // Returns true if the XlaOp represents valid, non-erroneous value. + bool valid() const { return handle_ >= 0; } + + // Returns true if the XlaOp was created by the XlaOp() constructor and + // not returned by a builder. + bool IsUninitialized() const { return builder_ == nullptr; } - bool operator!=(const XlaOp& rhs) const { - return handle_ != rhs.handle_ || builder_ != rhs.builder_; + bool IsIdenticalTo(const XlaOp& rhs) const { + return handle_ == rhs.handle_ && builder_ == rhs.builder_; } friend std::ostream& operator<<(std::ostream& out, const XlaOp& op) { @@ -65,6 +72,7 @@ class XlaOp { } private: + explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {} XlaOp(int64 handle, XlaBuilder* builder) : handle_(handle), builder_(builder) {} @@ -72,10 +80,38 @@ class XlaOp { friend class XlaBuilder; + // < 0 means "invalid handle". int64 handle_; - XlaBuilder* builder_; // Not owned. + + // Not owned. Non-null for any handle returned by XlaBuilder, even if the + // handle is invalid. + XlaBuilder* builder_; }; +// Arithmetic operator overloads for the XlaOp type. +XlaOp operator-(const XlaOp& x); +XlaOp operator+(const XlaOp& x, const XlaOp& y); +XlaOp operator-(const XlaOp& x, const XlaOp& y); +XlaOp operator*(const XlaOp& x, const XlaOp& y); +XlaOp operator/(const XlaOp& x, const XlaOp& y); +XlaOp operator%(const XlaOp& x, const XlaOp& y); + +// Bitwise operator overloads for the XlaOp type. +XlaOp operator~(const XlaOp& x); +XlaOp operator&(const XlaOp& x, const XlaOp& y); +XlaOp operator|(const XlaOp& x, const XlaOp& y); +XlaOp operator^(const XlaOp& x, const XlaOp& y); +XlaOp operator<<(const XlaOp& x, const XlaOp& y); +// Performs a right arithmetic shift if 'x' is a signed type, otherwise performs +// a right logical shift. +XlaOp operator>>(const XlaOp& x, const XlaOp& y); + +// We don't overload the relational operators (==, !=, <, <=, >, >=) because the +// semantics might be surprising since their result types are usually 'bool'. +// Further programmers may expect == to be a structural equality. +// We also choose not to overload any of the mutating operators (e.g., +=, -=) +// because the semantics might be misleading — XLA computations are immutable. + // A convenient interface for building up computations. // // Thread-compatible. @@ -122,6 +158,93 @@ class XlaBuilder { die_immediately_on_error_ = enabled; } + // Default dimension numbers used for a 2D convolution. + static constexpr int64 kConvBatchDimension = 0; + static constexpr int64 kConvFeatureDimension = 1; + static constexpr int64 kConvFirstSpatialDimension = 2; + static constexpr int64 kConvSecondSpatialDimension = 3; + static constexpr int64 kConvKernelOutputDimension = 0; + static constexpr int64 kConvKernelInputDimension = 1; + static constexpr int64 kConvKernelFirstSpatialDimension = 2; + static constexpr int64 kConvKernelSecondSpatialDimension = 3; + + // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for + // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for + // the kernel operand + // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. + static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( + int num_spatial_dims = 2); + + // Returns an error if the convolution dimension numbers have conflicts. + static Status Validate(const ConvolutionDimensionNumbers& dnum); + + // Returns a new XlaBuilder whose resultant Computation is used only by this + // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error + // behavior as the parent. + std::unique_ptr CreateSubBuilder(const string& computation_name); + + // Builds the computation with the requested operations, or returns a non-ok + // status. Note that all ops that have been enqueued will be moved to the + // computation being returned. + StatusOr Build(); + + // Builds the computation with the requested operations, or notes an error in + // the parent XlaBuilder and returns an empty computation if building failed. + // This function is intended to be used where the returned XlaComputation is + // only used by the parent XlaBuilder and hence further operation on the + // returned XlaComputation will simply be error'ed out if an error occurred + // while building this computation. If the built computation is to be used by + // a XlaBuilder other than the parent XlaBuilder then Build() should be used + // instead. + XlaComputation BuildAndNoteError(); + + // Returns a subgraph that roots on the given root. If the root is not a + // compile-time constant (see `IsConstant`), returns an error. + // + // This will copy the needed ops/computations to the subgraph. + StatusOr BuildConstantSubGraph(const XlaOp& root_op) const; + + // Returns the first error that was encountered while building the + // computation. When an error is encountered, by default we return a vacuous + // XlaOp and inform the user of the error that occurred while + // building the computation when they make a final call to Build(). + // + // See also set_die_immediately_on_error(). + Status first_error() const { return first_error_; } + + // Returns the shape of the given op. + StatusOr GetShape(const XlaOp& op) const; + + // Returns the (inferred) result for the current computation's shape. + StatusOr GetProgramShape() const; + + // Reports an error to the builder, by + // * storing it internally and capturing a backtrace if it's the first error + // (this deferred value will be produced on the call to + // Build()/GetShape()/...) + // * dying if die_immediately_on_error_ is true. + // Returns an XlaOp with an invalid handle but a valid builder. This value can + // be returned in place of a value in APIs that return an XlaOp. + XlaOp ReportError(const Status& error); + + // A helper function that converts a StatusOr into an XlaOp. + // If the Status was an error, reports the error to builder and returns an + // invalid XlaOp handle. + XlaOp ReportErrorOrReturn(const StatusOr& op); + + // A helper function that runs a function that returns a StatusOr and + // returns an XlaOp. + XlaOp ReportErrorOrReturn(const std::function()>& op_creator); + + // Returns true if 'operand' is a compile-time constant. A compile-time + // constant does not depend on any parameters, or on stateful operators such + // as `RngNormal` or `Infeed`. + // + // This tests whether a computation is a compile-time constant without + // evaluating the computation. + StatusOr IsConstant(const XlaOp& operand) const; + + private: // Enqueues a "retrieve parameter value" instruction for a parameter that was // passed to the computation. XlaOp Parameter(int64 parameter_number, const Shape& shape, @@ -342,26 +465,6 @@ class XlaBuilder { XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers); - // Default dimension numbers used for a 2D convolution. - static constexpr int64 kConvBatchDimension = 0; - static constexpr int64 kConvFeatureDimension = 1; - static constexpr int64 kConvFirstSpatialDimension = 2; - static constexpr int64 kConvSecondSpatialDimension = 3; - static constexpr int64 kConvKernelOutputDimension = 0; - static constexpr int64 kConvKernelInputDimension = 1; - static constexpr int64 kConvKernelFirstSpatialDimension = 2; - static constexpr int64 kConvKernelSecondSpatialDimension = 3; - - // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for - // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for - // the kernel operand - // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. - static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( - int num_spatial_dims = 2); - - // Returns an error if the convolution dimension numbers have conflicts. - static Status Validate(const ConvolutionDimensionNumbers& dnum); - // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, @@ -681,7 +784,18 @@ class XlaBuilder { tensorflow::gtl::ArraySlice dimensions); // Enqueues a sort (as increasing order) instruction onto the computation. - XlaOp Sort(const XlaOp& operand); + // If only keys are provided: + // * The keys must be a rank-1 tensor (i.e. an array). + // * The result is a sorted array of keys. + // + // If both keys and values are provided: + // * The keys and the values must be rank-1 tensors with the same dimensions. + // The element types of the tensors may be different. + // * The result is a tuple that consists of a sorted array of keys as the + // first element, and an array with their corresponding values as the second + // element. + XlaOp Sort(XlaOp keys, tensorflow::gtl::optional values = + tensorflow::gtl::nullopt); // Enqueues a clamp instruction onto the computation. XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); @@ -728,14 +842,6 @@ class XlaBuilder { // be the same as the given shape. XlaOp Recv(const Shape& shape, const ChannelHandle& handle); - // Returns true if 'operand' is a compile-time constant. A compile-time - // constant does not depend on any parameters, or on stateful operators such - // as `RngNormal` or `Infeed`. - // - // This tests whether a computation is a compile-time constant without - // evaluating the computation. - StatusOr IsConstant(const XlaOp& operand) const; - // Normalizes operand across spatial and batch dimensions for each feature. // // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` @@ -774,47 +880,6 @@ class XlaBuilder { const XlaOp& grad_output, float epsilon, int64 feature_index); - // Returns a new XlaBuilder whose resultant Computation is used only by this - // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error - // behavior as the parent. - std::unique_ptr CreateSubBuilder(const string& computation_name); - - // Builds the computation with the requested operations, or returns a non-ok - // status. Note that all ops that have been enqueued will be moved to the - // computation being returned. - StatusOr Build(); - - // Builds the computation with the requested operations, or notes an error in - // the parent XlaBuilder and returns an empty computation if building failed. - // This function is intended to be used where the returned XlaComputation is - // only used by the parent XlaBuilder and hence further operation on the - // returned XlaComputation will simply be error'ed out if an error occurred - // while building this computation. If the built computation is to be used by - // a XlaBuilder other than the parent XlaBuilder then Build() should be used - // instead. - XlaComputation BuildAndNoteError(); - - // Returns a subgraph that roots on the given root. If the root is not a - // compile-time constant (see `IsConstant`), returns an error. - // - // This will copy the needed ops/computations to the subgraph. - StatusOr BuildConstantSubGraph(const XlaOp& root_op) const; - - // Returns the first error that was encountered while building the - // computation. When an error is encountered, by default we return a vacuous - // XlaOp and inform the user of the error that occurred while - // building the computation when they make a final call to Build(). - // - // See also set_die_immediately_on_error(). - Status first_error() const { return first_error_; } - - // Returns the shape of the given op. - StatusOr GetShape(const XlaOp& op) const; - - // Returns the (inferred) result for the current computation's shape. - StatusOr GetProgramShape() const; - - private: StatusOr AddInstruction( HloInstructionProto&& instr, HloOpcode opcode, tensorflow::gtl::ArraySlice operands = {}); @@ -822,17 +887,6 @@ class XlaBuilder { void AddCalledComputation(const XlaComputation& computation, HloInstructionProto* instr); - // Notes that the error occurred by: - // * storing it internally and capturing a backtrace if it's the first error - // (this deferred value will be produced on the call to Build()) - // * dying if die_immediately_on_error_ is true - void NoteError(const Status& error); - - XlaOp NoteErrorOrReturn(const std::function()>& op_creator); - - // Helper method that creates an empty op and notes error. - XlaOp UnimplementedOp(); - StatusOr LookUpInstruction(const XlaOp& op) const; // Internal helper method that does the building for an arbitrary unary op. @@ -928,8 +982,958 @@ class XlaBuilder { bool die_immediately_on_error_ = false; XlaBuilder* parent_builder_{nullptr}; + + friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, + const Shape& shape, const string& name); + friend XlaOp ConstantLiteral(XlaBuilder* builder, + const LiteralSlice& literal); + template + friend XlaOp ConstantR0(XlaBuilder* builder, NativeT value); + template + friend XlaOp ConstantR1(XlaBuilder* builder, + tensorflow::gtl::ArraySlice values); + friend XlaOp ConstantR1(XlaBuilder* builder, + const tensorflow::core::Bitmap& values); + template + friend XlaOp ConstantR2( + XlaBuilder* builder, + std::initializer_list> values); + template + friend XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, + const Array& values, + const Layout& layout); + template + friend XlaOp ConstantFromArray(XlaBuilder* builder, + const Array& values); + template + friend XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, + const Array2D& values, + const Layout& layout); + template + friend XlaOp ConstantR2FromArray2D(XlaBuilder* builder, + const Array2D& values); + template + friend XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, + const Array3D& values, + const Layout& layout); + template + friend XlaOp ConstantR3FromArray3D(XlaBuilder* builder, + const Array3D& values); + template + friend XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, + const Array4D& values, + const Layout& layout); + template + friend XlaOp ConstantR4FromArray4D(XlaBuilder* builder, + const Array4D& values); + + template + friend XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); + + friend XlaOp Broadcast(const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); + + friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config); + + friend XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); + + friend XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice new_sizes); + + friend XlaOp Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions); + + friend XlaOp Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + + friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index, + int64 limit_index, int64 stride, int64 dimno); + + friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + + friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices); + + friend XlaOp ConcatInDim(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, + int64 dimension); + + friend void Trace(const string& tag, const XlaOp& operand); + + friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true, + const XlaOp& on_false); + friend XlaOp Tuple(XlaBuilder* builder, + tensorflow::gtl::ArraySlice elements); + friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); + friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); + friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers); + friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + Padding padding); + friend XlaOp ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + friend XlaOp ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers); + friend XlaOp ConvGeneral( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ConvolutionDimensionNumbers& dimension_numbers); + friend XlaOp ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers); + friend XlaOp Fft(const XlaOp& operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, + const string& config); + friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + const string& outfeed_config); + friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, + tensorflow::gtl::ArraySlice operands); + friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, + tensorflow::gtl::ArraySlice operands, + const Shape& shape); + friend XlaOp HostCompute(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, + const string& channel_name, int64 cost_estimate_ns, + const Shape& shape); + friend XlaOp Complex(const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Conj(const XlaOp& operand); + friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Not(const XlaOp& operand); + friend XlaOp ShiftLeft( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); + friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation); + friend XlaOp ReduceWindow( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, Padding padding); + friend XlaOp ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + friend XlaOp CrossReplicaSum( + const XlaOp& operand, + tensorflow::gtl::ArraySlice replica_group_ids); + friend XlaOp CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids, + const tensorflow::gtl::optional& channel_id); + friend XlaOp SelectAndScatter( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); + friend XlaOp SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); + friend XlaOp Abs(const XlaOp& operand); + friend XlaOp Atan2(const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Exp(const XlaOp& operand); + friend XlaOp Expm1(const XlaOp& operand); + friend XlaOp Floor(const XlaOp& operand); + friend XlaOp Ceil(const XlaOp& operand); + friend XlaOp Round(const XlaOp& operand); + friend XlaOp Log(const XlaOp& operand); + friend XlaOp Log1p(const XlaOp& operand); + friend XlaOp Sign(const XlaOp& operand); + friend XlaOp Clz(const XlaOp& operand); + friend XlaOp Cos(const XlaOp& operand); + friend XlaOp Sin(const XlaOp& operand); + friend XlaOp Tanh(const XlaOp& operand); + friend XlaOp Real(const XlaOp& operand); + friend XlaOp Imag(const XlaOp& operand); + friend XlaOp SqrtF32(const XlaOp& operand); + friend XlaOp SquareF32(const XlaOp& operand); + friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp IsFinite(const XlaOp& operand); + friend XlaOp ConvertElementType(const XlaOp& operand, + PrimitiveType new_element_type); + friend XlaOp BitcastConvertType(const XlaOp& operand, + PrimitiveType new_element_type); + friend XlaOp ReciprocalF32(const XlaOp& operand); + friend XlaOp Neg(const XlaOp& operand); + friend XlaOp Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice permutation); + friend XlaOp Rev(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions); + friend XlaOp Sort(XlaOp keys, tensorflow::gtl::optional values); + friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); + friend XlaOp Map(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands); + friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, + const Shape& shape); + friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); + friend XlaOp While(const XlaComputation& condition, + const XlaComputation& body, const XlaOp& init); + friend XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation); + friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits); + friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds); + friend void Send(const XlaOp& operand, const ChannelHandle& handle); + friend XlaOp Recv(XlaBuilder* builder, const Shape& shape, + const ChannelHandle& handle); + friend XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index); + friend XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index); + friend XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index); }; +// RAII-style object: sets the current sharding assignment in builder on +// construction, and sets back to the previous assignment on destruction. +class XlaScopedShardingAssignment { + public: + XlaScopedShardingAssignment(xla::XlaBuilder* builder, + tensorflow::gtl::optional sharding) + : builder_(builder), prev_sharding_(builder->sharding()) { + SetSharding(sharding); + } + + XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete; + XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) = + delete; + + ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); } + + private: + void SetSharding(const tensorflow::gtl::optional& sharding) { + if (sharding.has_value()) { + builder_->SetSharding(sharding.value()); + } else { + builder_->ClearSharding(); + } + } + + xla::XlaBuilder* const builder_; + tensorflow::gtl::optional prev_sharding_; +}; + +// Free functions for building XlaOps. The intention is that these will +// become the public API for building XlaOps rather than calling methods on +// XlaBuilder directly. + +// Enqueues a "retrieve parameter value" instruction for a parameter that was +// passed to the computation. +XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape, + const string& name); + +// Enqueues a constant with the value of the given literal onto the +// computation. +XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal); + +// Enqueues a constant onto the computation. Methods are templated on the +// native host type (NativeT) which corresponds to a specific XLA +// PrimitiveType as given in the following table: +// +// Native Type PrimitiveType +// ----------------------------- +// bool PRED +// int32 S32 +// int64 S64 +// uint32 U32 +// uint64 U64 +// float F32 +// double F64 +// +// Note: not all primitive types defined in xla_data.proto have a +// corresponding native type yet. +template +XlaOp ConstantR0(XlaBuilder* builder, NativeT value); +template +XlaOp ConstantR1(XlaBuilder* builder, + tensorflow::gtl::ArraySlice values); +XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values); +template +XlaOp ConstantR2(XlaBuilder* builder, + std::initializer_list> values); +template +XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, + const Array& values, + const Layout& layout); +template +XlaOp ConstantFromArray(XlaBuilder* builder, const Array& values); +template +XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, + const Array2D& values, + const Layout& layout); +template +XlaOp ConstantR2FromArray2D(XlaBuilder* builder, + const Array2D& values); +template +XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, + const Array3D& values, + const Layout& layout); +template +XlaOp ConstantR3FromArray3D(XlaBuilder* builder, + const Array3D& values); +template +XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, + const Array4D& values, + const Layout& layout); +template +XlaOp ConstantR4FromArray4D(XlaBuilder* builder, + const Array4D& values); + +// Enqueues a rank one constant (XlaBuilder* builder, vector) onto the +// computation. The vector has size 'length' and every element has the value +// 'value'. +template +XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); + +// Adds dimensions to an array by duplicating the data in the array. +// +// The new dimensions are inserted on the left, i.e. if +// broadcast_sizes has values {a0, ..., aN} and the operand shape +// has dimensions {b0, ..., bM} then the shape of the output has +// dimensions {a0, ..., aN, b0, ..., bM}. +// +// The new dimensions index into copies of the operand, i.e. +// +// output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] +XlaOp Broadcast(const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); + +// Enqueues a pad operation onto the computation that pads the given value on +// the edges as well as between the elements of the input. padding_config +// specifies the padding amount for each dimension. +XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config); + +// Enqueues an operation onto the computation that flattens the operand based +// on the dimension order (major/slowest-varying to minor/fastest-varying) +// given, followed by reshaping it into the shape with the given dimension +// sizes (also major to minor). Conceptually, this is a limited form of +// "shape casting". +XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); + +// Enqueues an operation onto the computation that collapses the operand, from +// first to last dimension (C order), then reshapes it to the given dimension +// sizes. Conceptually, this is a limited form of "shape casting". +XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice new_sizes); + +// Wrapper for Reshape. +// Enqueues an operation to collapse the provided dimensions; e.g. an +// operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to +// {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must +// be a consecutive, in-order subsequence of the operand dimensions. +// +// Note that collapsing a single dimension does nothing: +// +// {256} collapsing {0} => {256} +// {1} collapsing {0} => {1} +// +// Collapsing multiple dimensions produces a single result dimension: +// +// {256, 2} collapsing {0,1} => {512} +// {256, 2, 3} collapsing {0,1} => {512, 3} +// +// This could potentially cause data to be moved -- it provides a more +// structured form of reshaping than an arbitrary Reshape operation. +XlaOp Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions); + +// Enqueues a slice operation onto the computation that slices the operand +// from the start indices to the limit indices; e.g. +// +// x +// [ 0 1 2 3 ] +// y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] +// [ 8 9 a b ] +// +// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D +// range notation. +// The strides parameter determines the stride over the slice +XlaOp Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + +// Enqueues a slice operation in a given dimension, taking all other +// dimensions as they are; e.g. if dimno is 1 from start_index 2 to +// limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand +// for: +// +// array[:, 2:4:1, :] +XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, + int64 stride, int64 dimno); + +// Enqueues a slice operation onto the computation that slices the 'operand' +// from dynamic start indices which are passed in 'start_indices'. +// The size of the slice in each dimension is passed in 'slice_sizes', +// which specify the end point of exclusive slice intervals in each +// dimension [start, start + size). +// The shape of 'start_indices' must be rank == 1, with dimension size +// equal to the rank of the 'operand'. +// Slice index calculations are computed modulo input dimension sizes to +// prevent dynamic start indices from generating out-of-bound array accesses. +XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + +// Enqueues a dynamic update slice operation onto the computation, which +// updates a slice of 'operand' with 'update' at dynamic 'start_indices'. +// The shape of 'update' determines the shape of the slice of 'operand' +// which is updated. +// The indices specified in 'start_indices' specify the offset of the slice +// of 'operand' which is updated. +// +// update = {10, 11} // calculated at runtime. +// [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] +// [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] +// [7 8 9] [7 8 9 ] +// +// The shape of 'start_indices' must be rank == 1, with dimension size +// equal to the rank of the 'operand'. +// Slice index calculations are computed modulo update dimension sizes to +// prevent dynamic start indices from generating out-of-bound array accesses. +XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices); + +// Enqueues a concatenate instruction onto the computation. 'operands' must +// have >= 1 entry. +XlaOp ConcatInDim(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, int64 dimension); + +// Enqueue a tracing operation onto the computation; the computation will emit +// a logging message with the operand. +void Trace(const string& tag, const XlaOp& operand); + +// Enqueues a conditional-move-like select operation onto the computation; +// predicated on pred, selects between on_true and on_false. +XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); + +// Enqueues a tuple-creation instruction onto the computation. +XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice elements); + +// Enqueues a tuple-element-get instruction onto the computation. +XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); + +// Enqueues an equal-to comparison instruction onto the computation. +XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a not-equal comparison instruction onto the computation. +XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a greater-or-equal comparison instruction onto the computation. +XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a greater-than comparison instruction onto the computation. +XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a less-than comparison instruction onto the computation. +XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a less-or-equal comparison instruction onto the computation. +XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a dot instruction onto the computation. +XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); + +// Enqueues a general dot instruction onto the computation. +XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers); + +// Enqueues a convolution instruction onto the computation, which uses the +// default convolution dimension numbers. +XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding); + +// Enqueues a convolution instruction onto the computation, with the caller +// provided padding configuration in the format returned by MakePadding(). +XlaOp ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + +// Enqueues a convolution instruction onto the computation, with the caller +// provided dimension numbers configuration. +XlaOp ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers); + +// Enqueues a convolution instruction onto the computation, with the caller +// provided padding configuration as well as the dimension numbers. +XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ConvolutionDimensionNumbers& dimension_numbers); + +// Enqueues a convolution instruction onto the computation, with the caller +// provided padding configuration, dilation factors and dimension numbers. +XlaOp ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers); + +// Enqueues an FFT instruction onto the computation, of the given type and +// with the given FFT length. +XlaOp Fft(const XlaOp& operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + +// Enqueues an infeed instruction onto the computation, which writes data of +// the given shape to the infeed buffer of the device. +XlaOp Infeed(XlaBuilder* builder, const Shape& shape, + const string& config = ""); + +// Enqueues an outfeed instruction onto the computation. This instruction +// generates outgoing data transfers for the given data. +// +// shape_with_layout communicates the laid out shape that we want to outfeed +// -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error +// will occur. +void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + const string& outfeed_config); + +// Enqueues a call instruction onto the computation. +XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, + tensorflow::gtl::ArraySlice operands); + +// Enqueues a custom call instruction onto the computation. +// During code generation, a call instruction is emitted which targets a +// symbol with the name |call_target_name|. The |operands| are passed to the +// call instruction. |shape| is the resultant shape. +XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, + tensorflow::gtl::ArraySlice operands, + const Shape& shape); + +// Enqueues a pseudo-op to represent host-side computation data-dependencies. +// During code generation, host send and receive operations will be generated +// to transfer |operands| to the host and a single result of |shape| back to +// the device. Host send/recv operations are emitted using |channel_name|. +// Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO +// instruction scheduling. +XlaOp HostCompute(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, + const string& channel_name, int64 cost_estimate_ns, + const Shape& shape); + +// The following methods enqueue element-wise binary arithmetic operations +// onto the computation. The shapes of the operands have to match unless one +// of the operands is a scalar, or an explicit broadcast dimension is given +// (see g3doc for more details). + +// Enqueues a complex compose instruction onto the computation. +XlaOp Complex(const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a complex conjugate instruction onto the computation. +XlaOp Conj(const XlaOp& operand); + +// Enqueues an add instruction onto the computation. +XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a subtract instruction onto the computation. +XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a multiply instruction onto the computation. +XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a divide instruction onto the computation. +XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a remainder instruction onto the computation. +XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a max instruction onto the computation. +XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a min instruction onto the computation. +XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Element-wise logical operators +XlaOp And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +XlaOp Not(const XlaOp& operand); + +XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); +XlaOp ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); +XlaOp ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Reduces an array among the provided dimensions, given "computation" as a +// reduction operator. +XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); + +// Convenience wrapper around the above that reduces all the dimensions in the +// operand shape. +XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation); + +// Enqueues a windowed reduce instruction onto the computation. +XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding); + +// As ReduceWindow(), but the padding is given in the format +// returned by MakePadding(). +XlaOp ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + +// Returns the sum of the operand value within each subgroup of replicas. All +// replicas supply one input to the sum and all replicas receive the resulting +// sum for each subgroup. +XlaOp CrossReplicaSum( + const XlaOp& operand, + tensorflow::gtl::ArraySlice replica_group_ids = {}); + +// Enqueues an operation that do an AllReduce of the operand cross cores. Here +// AllReduce means doing a reduction on the input operand cross cores and then +// broadcasting the reduction result to those cores. The reduction function is +// defined by `computation`, which should be a commutative computation on +// scalars, e.g., add, min, or max. The way that AllReduce is applied is +// configured by: +// +// - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all +// replicas belong to one group. Allreduce will be applied within subgroups. +// For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, +// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. +// +// - `channel_id`: for Allreduce nodes from different models, if they have the +// same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be +// applied cross models. +// +// TODO(b/79737069): Rename this to AllReduce when it's ready to use. +XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids = {}, + const tensorflow::gtl::optional& + channel_id = tensorflow::gtl::nullopt); + +// Enqueues an operation that scatters the `source` array to the selected +// indices of each window. +XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter); + +// As SelectAndScatter(), but the padding is given in the format +// returned by MakePadding(). +XlaOp SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); + +// Enqueues an abs instruction onto the computation. +XlaOp Abs(const XlaOp& operand); + +// Enqueues a atan2 instruction onto the computation. +XlaOp Atan2(const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues an exp instruction onto the computation. +XlaOp Exp(const XlaOp& operand); + +// Enqueues an expm1 instruction onto the computation. +XlaOp Expm1(const XlaOp& operand); + +// Enqueues a floor instruction onto the computation. +XlaOp Floor(const XlaOp& operand); + +// Enqueues a ceil instruction onto the computation. +XlaOp Ceil(const XlaOp& operand); + +// Enqueues a round instruction onto the computation, rounding to nearest even +// with half-way cases rounding away from zero. +XlaOp Round(const XlaOp& operand); + +// Enqueues an log instruction (natural logarithm) onto the computation. +XlaOp Log(const XlaOp& operand); + +// Enqueues an log1p instruction (log(x+1)) onto the computation. +XlaOp Log1p(const XlaOp& operand); + +// Enqueues a sign instruction onto the computation. +XlaOp Sign(const XlaOp& operand); + +// Enqueues a count leading zeros instruction onto the computation. +XlaOp Clz(const XlaOp& operand); + +// Enqueues a cosine instruction onto the computation. +XlaOp Cos(const XlaOp& operand); + +// Enqueues a sine instruction onto the computation. +XlaOp Sin(const XlaOp& operand); + +// Enqueues a tanh instruction onto the computation. +XlaOp Tanh(const XlaOp& operand); + +// Enqueues a real-part instruction onto the computation. +XlaOp Real(const XlaOp& operand); + +// Enqueues an imaginary-part instruction onto the computation. +XlaOp Imag(const XlaOp& operand); + +// Enqueues a float32 sqrt instruction onto the computation. +// (float32 is specified as there is an implicit float32 0.5f constant +// exponent). +XlaOp SqrtF32(const XlaOp& operand); + +// Enqueues a float32 square instruction onto the computation. +// (float32 is specified as there is an implicit float32 2.0f constant +// exponent). +XlaOp SquareF32(const XlaOp& operand); + +// Enqueues a lhs^rhs computation onto the computation. +XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues an operator that tests if the operand's values are finite, i.e., +// not Inf or NaN. Defined only for floating-point types. Returns an array of +// booleans with the same shape where entries are true iff the corresponding +// entry was NaN. +XlaOp IsFinite(const XlaOp& operand); + +// Enqueues a convert instruction onto the computation that changes the +// element type of the operand array to primitive_type. +XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); + +// Enqueues a no-op instruction onto the computation that changes +// the element type of the operand array to primitive_type. The +// bit-widths of the source and destination element types must be +// identical. +XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); + +// Enqueues a float32 reciprocal instruction onto the computation. +// (float32 is specified as there is an implicit float32 -1.0f constant +// exponent). +// +// TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the +// shape of the operand. +XlaOp ReciprocalF32(const XlaOp& operand); + +// Enqueues a negate instruction onto the computation. +XlaOp Neg(const XlaOp& operand); + +// Enqueues a transpose instruction onto the computation. +XlaOp Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice permutation); + +// Enqueues a reverse instruction onto the computation. The order of the +// elements in the given dimensions is reversed (i.e., the element at index i +// is moved to index dimension_size - 1 - i). +XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions); + +// * The result is a sorted array of keys. +// +// If both keys and values are provided: +// * The keys and the values must be rank-1 tensors with the same dimensions. +// The element types of the tensors may be different. +// * The result is a tuple that consists of a sorted array of keys as the +// first element, and an array with their corresponding values as the second +// element. +XlaOp Sort(XlaOp keys, + tensorflow::gtl::optional values = tensorflow::gtl::nullopt); + +// Enqueues a clamp instruction onto the computation. +XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); + +// Enqueues a map instruction onto the computation. +XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands = {}); + +// Enqueues a N(mu, sigma) random number generation instruction onto the +// computation. +XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape); + +// Enqueues a U(a, b) random number generation instruction onto the +// computation. Returns values in the semi-open interval [a, b). +XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); + +// Enqueues a while node onto the computation. +XlaOp While(const XlaComputation& condition, const XlaComputation& body, + const XlaOp& init); + +// Enqueues a conditional node onto the computation. +XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation); + +// Enqueues a ReducePrecision node onto the computation. +XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits); + +// Enqueues a Gather node onto the computation. +XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds); + +// Enqueues a Send node onto the computation, to send the given operand to +// a Recv instruction that shares the same channel handle. +void Send(const XlaOp& operand, const ChannelHandle& handle); + +// Enqueues a Recv node onto the computation. The data comes from a Send +// instruction that shares the same channel handle and its shape must +// be the same as the given shape. +XlaOp Recv(XlaBuilder* builder, const Shape& shape, + const ChannelHandle& handle); + +// Normalizes operand across spatial and batch dimensions for each feature. +// +// Returns a tuple (normalized, batch_mean, batch_var) where `normalized` +// is the normalized result and batch_mean and batch_var are the mean and +// variance, respectively, across batch for the operand. +XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index); + +// Normalizes operand across spatial and batch dimensions for each feature. +// +// `BatchNormInference` is equivalent to calling `BatchNormTraining` without +// computing `mean` and `variance` for each batch inside the operation. It +// uses the input `mean` and `variance` instead as estimated values. The +// purpose of this op is to reduce latency in inference, hence the name +// `BatchNormInference`. +// +// The output has the same shape as `operand`, and contains the normalized +// values for each batch. +XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index); + +// Calculates the gradients of a batch norm op. +// +// The inputs `batch_mean` and `batch_var` represent the mean and variance +// across the batch. +// +// Returns a tuple of three elements: +// - grad_operand: Gradient with respect to input `operand` +// - grad_offset: Gradient with respect to input `offset` +// - grad_scale: Gradient with respect to input `scale` +XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index); + +// Implementation details below this point. + template XlaOp XlaBuilder::ConstantR0(NativeT value) { return ConstantLiteral(*Literal::CreateR0(value)); @@ -1005,34 +2009,93 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D& values) { return ConstantFromArray(values); } -// RAII-style object: sets the current sharding assignment in builder on -// construction, and sets back to the previous assignment on destruction. -class XlaScopedShardingAssignment { - public: - XlaScopedShardingAssignment(xla::XlaBuilder* builder, - tensorflow::gtl::optional sharding) - : builder_(builder), prev_sharding_(builder->sharding()) { - SetSharding(sharding); - } +// Free function template implementations. - XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete; - XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) = - delete; +template +XlaOp ConstantR0(XlaBuilder* builder, NativeT value) { + return ConstantLiteral(builder, *Literal::CreateR0(value)); +} - ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); } +template +XlaOp ConstantR1(XlaBuilder* builder, + tensorflow::gtl::ArraySlice values) { + return ConstantLiteral(builder, *Literal::CreateR1(values)); +} - private: - void SetSharding(const tensorflow::gtl::optional& sharding) { - if (sharding.has_value()) { - builder_->SetSharding(sharding.value()); - } else { - builder_->ClearSharding(); - } - } +template +XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) { + Literal literal(ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {length})); + literal.PopulateWithValue(value); + return ConstantLiteral(builder, literal); +} - xla::XlaBuilder* const builder_; - tensorflow::gtl::optional prev_sharding_; -}; +inline XlaOp ConstantR1(XlaBuilder* builder, + const tensorflow::core::Bitmap& values) { + return ConstantLiteral(builder, *Literal::CreateR1(values)); +} + +template +XlaOp ConstantR2(XlaBuilder* builder, + std::initializer_list> values) { + return ConstantLiteral(builder, *Literal::CreateR2(values)); +} + +template +XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, + const Array& values, + const Layout& layout) { + return ConstantLiteral( + builder, *Literal::CreateFromArrayWithLayout(values, layout)); +} + +template +XlaOp ConstantFromArray(XlaBuilder* builder, const Array& values) { + return ConstantLiteral(builder, *Literal::CreateFromArray(values)); +} + +template +XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, + const Array2D& values, + const Layout& layout) { + return ConstantLiteral( + builder, *Literal::CreateFromArrayWithLayout(values, layout)); +} + +template +XlaOp ConstantR2FromArray2D(XlaBuilder* builder, + const Array2D& values) { + return ConstantLiteral(builder, + *Literal::CreateR2FromArray2D(values)); +} + +template +XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, + const Array3D& values, + const Layout& layout) { + return ConstantLiteral( + builder, + *Literal::CreateR3FromArray3DWithLayout(values, layout)); +} + +template +XlaOp ConstantR3FromArray3D(XlaBuilder* builder, + const Array3D& values) { + return ConstantFromArray(builder, values); +} + +template +XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, + const Array4D& values, + const Layout& layout) { + return ConstantFromArrayWithLayout(builder, values, layout); +} + +template +XlaOp ConstantR4FromArray4D(XlaBuilder* builder, + const Array4D& values) { + return ConstantFromArray(builder, values); +} } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc index 2df3ea3af0d4fcfb9bc803feebd96f09042ab1f3..3b8beb2c7840e23752b5f47bbc5f55d89751884d 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc @@ -53,16 +53,86 @@ class XlaBuilderTest : public ::testing::Test { TEST_F(XlaBuilderTest, OnePlusTwo) { XlaBuilder b(TestName()); - b.Add(b.ConstantR0(1.0), b.ConstantR0(2.0)); + Add(ConstantR0(&b, 1.0), ConstantR0(&b, 2.0)); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Add(op::Constant(), op::Constant())); } +TEST_F(XlaBuilderTest, UnaryOperatorsBuildExpectedHLO) { + auto test_unary_operator = + [&](std::function op, + ::testing::Matcher matches_pattern) { + XlaBuilder b(TestName()); + op(ConstantR0(&b, 1)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, matches_pattern); + }; + test_unary_operator([](XlaOp x) { return -x; }, op::Negate(op::Constant())); + test_unary_operator([](XlaOp x) { return ~x; }, op::Not(op::Constant())); +} + +TEST_F(XlaBuilderTest, BinaryOperatorsBuildExpectedHLO) { + auto test_binary_operator = + [&](std::function op, + ::testing::Matcher matches_pattern) { + XlaBuilder b(TestName()); + op(ConstantR0(&b, 1), ConstantR0(&b, 2)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, matches_pattern); + }; + + test_binary_operator([](XlaOp x, XlaOp y) { return x + y; }, + op::Add(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x - y; }, + op::Subtract(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x * y; }, + op::Multiply(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x / y; }, + op::Divide(op::Constant(), op::Constant())); + + test_binary_operator([](XlaOp x, XlaOp y) { return x & y; }, + op::And(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x | y; }, + op::Or(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x ^ y; }, + op::Xor(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x << y; }, + op::ShiftLeft(op::Constant(), op::Constant())); + test_binary_operator( + [](XlaOp x, XlaOp y) { return x >> y; }, + op::ShiftRightArithmetic(op::Constant(), op::Constant())); + + auto test_unsigned_binary_operator = + [&](std::function op, + ::testing::Matcher matches_pattern) { + XlaBuilder b(TestName()); + op(ConstantR0(&b, 1), ConstantR0(&b, 2)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, matches_pattern); + }; + test_unsigned_binary_operator( + [](XlaOp x, XlaOp y) { return x >> y; }, + op::ShiftRightLogical(op::Constant(), op::Constant())); +} + +TEST_F(XlaBuilderTest, ShiftRightOperatorOnNonIntegerProducesError) { + XlaBuilder b(TestName()); + ConstantR0(&b, 1) >> ConstantR0(&b, 2); + auto statusor = b.Build(); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Argument to >> operator does not have an integral type")); +} + TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) { XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {3, 5}), "x"); - b.Add(x, b.ConstantR0(1.0)); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3, 5}), "x"); + Add(x, ConstantR0(&b, 1.0)); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Add(op::Parameter(), op::Broadcast(op::Constant()))); @@ -72,9 +142,9 @@ TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) { XlaBuilder b(TestName()); const auto& x_shape = ShapeUtil::MakeShape(S32, {2, 4, 6}); const auto& y_shape = ShapeUtil::MakeShape(S32, {2, 4}); - auto x = b.Parameter(0, x_shape, "x"); - auto y = b.Parameter(1, y_shape, "y"); - auto add = b.Add(x, y, /*broadcast_dimensions=*/{0, 1}); + auto x = Parameter(&b, 0, x_shape, "x"); + auto y = Parameter(&b, 1, y_shape, "y"); + auto add = Add(x, y, /*broadcast_dimensions=*/{0, 1}); TF_ASSERT_OK_AND_ASSIGN(auto add_shape, b.GetShape(add)); EXPECT_TRUE(ShapeUtil::Equal(add_shape, x_shape)); @@ -86,8 +156,8 @@ TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) { TEST_F(XlaBuilderTest, XPlusX) { XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(S32, {1, 3, 5, 7}), "x"); - b.Add(x, x); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {1, 3, 5, 7}), "x"); + Add(x, x); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Add(op::Parameter(0), op::Parameter(0))); @@ -95,9 +165,9 @@ TEST_F(XlaBuilderTest, XPlusX) { TEST_F(XlaBuilderTest, ShapeInferenceError) { XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(U32, {2, 4, 6}), "x"); - auto y = b.Parameter(1, ShapeUtil::MakeShape(U32, {2, 4}), "y"); - b.Add(x, y); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(U32, {2, 4, 6}), "x"); + auto y = Parameter(&b, 1, ShapeUtil::MakeShape(U32, {2, 4}), "y"); + Add(x, y); auto statusor = BuildHloModule(&b); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("shape inference")); @@ -105,12 +175,12 @@ TEST_F(XlaBuilderTest, ShapeInferenceError) { TEST_F(XlaBuilderTest, ParameterAlreadyRegistered) { XlaBuilder b_call("add"); - b_call.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x"); + Parameter(&b_call, 0, ShapeUtil::MakeShape(PRED, {}), "x"); XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x"); - auto y = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "y"); - b.Add(x, y); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "x"); + auto y = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "y"); + Add(x, y); auto statusor = BuildHloModule(&b); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -119,16 +189,16 @@ TEST_F(XlaBuilderTest, ParameterAlreadyRegistered) { TEST_F(XlaBuilderTest, Call) { XlaBuilder b_call("the_only_to_apply"); - auto p0 = b_call.Parameter(0, ShapeUtil::MakeShape(F32, {}), "p0"); - auto p1 = b_call.Parameter(1, ShapeUtil::MakeShape(F32, {}), "p1"); - b_call.Add(p0, p1); + auto p0 = Parameter(&b_call, 0, ShapeUtil::MakeShape(F32, {}), "p0"); + auto p1 = Parameter(&b_call, 1, ShapeUtil::MakeShape(F32, {}), "p1"); + Add(p0, p1); TF_ASSERT_OK_AND_ASSIGN(auto call, b_call.Build()); XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - auto one = b.ConstantR0(1); - auto two = b.ConstantR0(2); - b.Add(b.Call(call, {x, y}), b.Call(call, {one, two})); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y"); + auto one = ConstantR0(&b, 1); + auto two = ConstantR0(&b, 2); + Add(Call(&b, call, {x, y}), Call(&b, call, {one, two})); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Add(op::Call(op::Parameter(), op::Parameter()), @@ -137,9 +207,9 @@ TEST_F(XlaBuilderTest, Call) { TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) { XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x"); - auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {1, 2, 1}), "y"); - b.Add(x, y); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x"); + auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {1, 2, 1}), "y"); + Add(x, y); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); // Expected: @@ -158,9 +228,9 @@ TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) { TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) { XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3}), "x"); - auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {2, 1, 4}), "y"); - b.Add(x, y, /*broadcast_dimensions=*/{0, 1}); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "x"); + auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {2, 1, 4}), "y"); + Add(x, y, /*broadcast_dimensions=*/{0, 1}); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); // The binary operation has in-dim broadcast and degenerate broadcast, should @@ -183,9 +253,10 @@ TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) { TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { XlaBuilder b1("b1"); - auto p0 = b1.Parameter(0, ShapeUtil::MakeShape(F32, {}), "p0"); + auto p0 = Parameter(&b1, 0, ShapeUtil::MakeShape(F32, {}), "p0"); XlaBuilder builder("main"); - builder.Add(p0, p0); + auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "p"); + Add(p, p0); auto statusor = builder.Build(); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -196,8 +267,8 @@ TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { TEST_F(XlaBuilderTest, ReshapeDefaultOrder) { XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); - b.Reshape(x, /*new_sizes=*/{6, 35}); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); + Reshape(x, /*new_sizes=*/{6, 35}); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Parameter())); @@ -205,8 +276,8 @@ TEST_F(XlaBuilderTest, ReshapeDefaultOrder) { TEST_F(XlaBuilderTest, ReshapeHasTranspose) { XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); - b.Reshape(x, /*dimensions=*/{3, 2, 1, 0}, /*new_sizes=*/{6, 35}); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); + Reshape(x, /*dimensions=*/{3, 2, 1, 0}, /*new_sizes=*/{6, 35}); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Transpose(op::Parameter()))); @@ -214,25 +285,38 @@ TEST_F(XlaBuilderTest, ReshapeHasTranspose) { TEST_F(XlaBuilderTest, Transpose) { XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); - b.Transpose(x, /*permutation=*/{1, 0}); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); + Transpose(x, /*permutation=*/{1, 0}); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Transpose(op::Parameter())); } -// TODO(b/65209188): Create a dedicated lowering for Xor. -TEST_F(XlaBuilderTest, Xor) { +TEST_F(XlaBuilderTest, ReportError) { XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x"); - auto y = b.Parameter(1, ShapeUtil::MakeShape(PRED, {}), "y"); - b.Xor(x, y); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); + Add(b.ReportError(InvalidArgument("a test error")), x); + auto statusor = b.Build(); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error")); +} + +TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesNonErrors) { + XlaBuilder b(TestName()); + StatusOr op(ConstantR0(&b, 1.0)); + Add(b.ReportErrorOrReturn(op), ConstantR0(&b, 2.0)); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); - LOG(ERROR) << module->ToString(); - EXPECT_THAT(root, - op::Or(op::And(op::Not(op::Parameter(0)), op::Parameter(1)), - op::And(op::Parameter(0), op::Not(op::Parameter(1))))); + EXPECT_THAT(root, op::Add(op::Constant(), op::Constant())); +} + +TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) { + XlaBuilder b(TestName()); + StatusOr op(InvalidArgument("a test error")); + Add(b.ReportErrorOrReturn(op), ConstantR0(&b, 2.0)); + auto statusor = b.Build(); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error")); } } // namespace diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 3f059cac30b5d36ab1d097bf200547533822e3d0..15eeb2ea13607d43c995197f8f0e3c58abd4d94a 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -248,6 +248,12 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } } + if (layout.format() == SPARSE) { + if (!layout.padded_dimensions().empty()) { + return InvalidArgument("Sparse layout has padded dimensions"); + } + } + return Status::OK(); } diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 19e6d288c00a7a541e01390af4946c0caa06615e..eeabf835ac348a5ba55699631188b0e329c98c43 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -2142,6 +2142,7 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { } break; case TUPLE: + case TOKEN: // Nothing to do but assign the shape which is done above. return; default: @@ -2294,6 +2295,9 @@ StatusOr> Literal::CreateFromProto( } return Status::OK(); } + if (piece->subshape().element_type() == TOKEN) { + return Status::OK(); + } CHECK(ShapeUtil::IsArray(piece->subshape())); TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); @@ -2355,7 +2359,6 @@ LiteralSlice::LiteralSlice(const LiteralBase& literal, BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) : LiteralBase(), shape_(MakeUnique(shape)) { CHECK(ShapeUtil::IsArray(*shape_)); - CHECK_NE(src_buf_ptr, nullptr); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = Piece(); diff --git a/tensorflow/compiler/xla/overflow_util.h b/tensorflow/compiler/xla/overflow_util.h new file mode 100644 index 0000000000000000000000000000000000000000..8657d3a4bfa992b9ca0619f24923fd4542eed894 --- /dev/null +++ b/tensorflow/compiler/xla/overflow_util.h @@ -0,0 +1,50 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_ + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Multiply two nonnegative int64's, returning negative for overflow +inline int64 MultiplyWithoutOverflow(const int64 x, const int64 y) { + // Multiply in uint64 rather than int64 since signed overflow is undefined. + // Negative values will wrap around to large unsigned values in the casts + // (see section 4.7 [conv.integral] of the C++14 standard). + const uint64 ux = x; + const uint64 uy = y; + const uint64 uxy = ux * uy; + + // Check if we overflow uint64, using a cheap check if both inputs are small + if (TF_PREDICT_FALSE((ux | uy) >> 32 != 0)) { + // Ensure nonnegativity. Note that negative numbers will appear "large" + // to the unsigned comparisons above. + CHECK(x >= 0 && y >= 0); + + // Otherwise, detect overflow using a division + if (ux != 0 && uxy / ux != uy) return -1; + } + + // Cast back to signed. Any negative value will signal an error. + return static_cast(uxy); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 445cee1aa7b462f7ae2b6b0771ff57f0c8f3db99..b5ba4e2d429e465649fc1b7acaf19fcb75f6d1ef 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/python/local_computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/util.h" @@ -174,73 +175,73 @@ StatusOr> CompiledLocalComputation::Execute( GetReplicaCount()); for (int replica = 0; replica < GetReplicaCount(); ++replica) { - pool.Schedule([this, client, replica, &arguments, &shapes_with_layout, - &results] { - StatusOr device_ordinal_status = - client->ReplicaNumberToDeviceOrdinal(replica); - if (!device_ordinal_status.ok()) { - results[replica] = device_ordinal_status.status(); - return; - } - const int device_ordinal = device_ordinal_status.ValueOrDie(); - VLOG(3) << "Replica " << replica - << " mapped to device ordinal for execution: " - << device_ordinal; - - // Transfer arguments in - std::vector scoped_buffers; - scoped_buffers.reserve(arguments.size()); - for (int i = 0; i < arguments.size(); ++i) { - const Literal& argument = arguments[i]; - const tensorflow::gtl::optional& shape_with_layout = - shapes_with_layout[i]; - - StatusOr pushed; - if (shape_with_layout) { - std::unique_ptr relaid = - argument.Relayout(shape_with_layout.value()); - pushed = ToBuffer(client, device_ordinal, *relaid); - } else { - pushed = ToBuffer(client, device_ordinal, argument); - } - if (!pushed.ok()) { - results[replica] = pushed.status(); - return; - } - - scoped_buffers.push_back(std::move(pushed).ValueOrDie()); - } - - // Execute - std::vector argument_buffers; - argument_buffers.reserve(scoped_buffers.size()); - for (auto& buffer : scoped_buffers) { - argument_buffers.push_back(&buffer); - } - - DeviceAssignment device_assignment = - client->backend() - .computation_placer() - ->AssignDevices(GetReplicaCount(), /*computation_count=*/1) - .ConsumeValueOrDie(); - - ExecutableRunOptions options; - options.set_device_ordinal(device_ordinal); - options.set_allocator(client->backend().memory_allocator()); - options.set_intra_op_thread_pool( - client->backend().eigen_intra_op_thread_pool_device()); - options.set_device_assignment(&device_assignment); - StatusOr result_buffer_status = - executable_->Run(argument_buffers, options); - if (!result_buffer_status.ok()) { - results[replica] = result_buffer_status.status(); - return; - } - - // Transfer result out - results[replica] = client->ShapedBufferToLiteral( - std::move(result_buffer_status).ValueOrDie()); - }); + pool.Schedule( + [this, client, replica, &arguments, &shapes_with_layout, &results] { + StatusOr device_ordinal_status = + client->ReplicaNumberToDeviceOrdinal(replica); + if (!device_ordinal_status.ok()) { + results[replica] = device_ordinal_status.status(); + return; + } + const int device_ordinal = device_ordinal_status.ValueOrDie(); + VLOG(3) << "Replica " << replica + << " mapped to device ordinal for execution: " + << device_ordinal; + + // Transfer arguments in + std::vector scoped_buffers; + scoped_buffers.reserve(arguments.size()); + for (int i = 0; i < arguments.size(); ++i) { + const Literal& argument = arguments[i]; + const tensorflow::gtl::optional& shape_with_layout = + shapes_with_layout[i]; + + StatusOr pushed; + if (shape_with_layout) { + std::unique_ptr relaid = + argument.Relayout(shape_with_layout.value()); + pushed = ToBuffer(client, device_ordinal, *relaid); + } else { + pushed = ToBuffer(client, device_ordinal, argument); + } + if (!pushed.ok()) { + results[replica] = pushed.status(); + return; + } + + scoped_buffers.push_back(std::move(pushed).ValueOrDie()); + } + + // Execute + std::vector argument_buffers; + argument_buffers.reserve(scoped_buffers.size()); + for (auto& buffer : scoped_buffers) { + argument_buffers.push_back(&buffer); + } + + DeviceAssignment device_assignment = + client->backend() + .computation_placer() + ->AssignDevices(GetReplicaCount(), /*computation_count=*/1) + .ConsumeValueOrDie(); + + ExecutableRunOptions options; + options.set_device_ordinal(device_ordinal); + options.set_allocator(client->backend().memory_allocator()); + options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); + StatusOr result_buffer_status = + executable_->Run(argument_buffers, options); + if (!result_buffer_status.ok()) { + results[replica] = result_buffer_status.status(); + return; + } + + // Transfer result out + results[replica] = client->ShapedBufferToLiteral( + std::move(result_buffer_status).ValueOrDie()); + }); } } @@ -341,14 +342,11 @@ StatusOr LocalComputationBuilder::Build() { LocalOp LocalComputationBuilder::Parameter(int64 parameter_number, const Shape& shape, const string& name) { - return builder_.Parameter(parameter_number, shape, name); + return xla::Parameter(&builder_, parameter_number, shape, name); } -std::unique_ptr LocalComputationBuilder::GetShape( - const LocalOp& operand) { - auto result = MakeUnique(); - *result = builder_.GetShape(operand.op()).ValueOrDie(); - return result; +StatusOr LocalComputationBuilder::GetShape(const LocalOp& operand) { + return builder_.GetShape(operand.op()); } StatusOr LocalComputationBuilder::GetReturnValueShape() { @@ -357,72 +355,70 @@ StatusOr LocalComputationBuilder::GetReturnValueShape() { } LocalOp LocalComputationBuilder::Infeed(const Shape& shape) { - return builder_.Infeed(shape); + return xla::Infeed(&builder_, shape); } void LocalComputationBuilder::Outfeed(const LocalOp& operand, const Shape& shape, const string& outfeed_config) { - builder_.Outfeed(operand.op(), shape, outfeed_config); + xla::Outfeed(operand.op(), shape, outfeed_config); } LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { - return builder_.ConstantLiteral(literal); + return xla::ConstantLiteral(&builder_, literal); } LocalOp LocalComputationBuilder::Broadcast( const LocalOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { - return builder_.Broadcast(operand.op(), broadcast_sizes); + return xla::Broadcast(operand.op(), broadcast_sizes); } LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, const LocalOp& padding_value, const PaddingConfig& padding_config) { - return builder_.Pad(operand.op(), padding_value.op(), padding_config); + return xla::Pad(operand.op(), padding_value.op(), padding_config); } LocalOp LocalComputationBuilder::Reshape( const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { - return builder_.Reshape(operand.op(), dimensions, new_sizes); + return xla::Reshape(operand.op(), dimensions, new_sizes); } LocalOp LocalComputationBuilder::Collapse( const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { - return builder_.Collapse(operand.op(), dimensions); + return xla::Collapse(operand.op(), dimensions); } LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) { - return builder_.CrossReplicaSum(operand.op()); + return xla::CrossReplicaSum(operand.op()); } LocalOp LocalComputationBuilder::Slice( const LocalOp& operand, tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - return builder_.Slice(operand.op(), start_indices, limit_indices, strides); + return xla::Slice(operand.op(), start_indices, limit_indices, strides); } LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno) { - return builder_.SliceInDim(operand.op(), start_index, limit_index, stride, - dimno); + return xla::SliceInDim(operand.op(), start_index, limit_index, stride, dimno); } LocalOp LocalComputationBuilder::DynamicSlice( const LocalOp& operand, const LocalOp& start_indices, tensorflow::gtl::ArraySlice slice_sizes) { - return builder_.DynamicSlice(operand.op(), start_indices.op(), slice_sizes); + return xla::DynamicSlice(operand.op(), start_indices.op(), slice_sizes); } LocalOp LocalComputationBuilder::DynamicUpdateSlice( const LocalOp& operand, const LocalOp& update, const LocalOp& start_indices) { - return builder_.DynamicUpdateSlice(operand.op(), update.op(), - start_indices.op()); + return xla::DynamicUpdateSlice(operand.op(), update.op(), start_indices.op()); } LocalOp LocalComputationBuilder::ConcatInDim( @@ -432,7 +428,7 @@ LocalOp LocalComputationBuilder::ConcatInDim( for (const auto& op : operands) { xla_ops.push_back(op.op()); } - return builder_.ConcatInDim(xla_ops, dimension); + return xla::ConcatInDim(&builder_, xla_ops, dimension); } LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( @@ -442,7 +438,7 @@ LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( tensorflow::gtl::ArraySlice> padding, const LocalOp& source, const LocalOp& init_value, const LocalComputation& scatter) { - return builder_.SelectAndScatterWithGeneralPadding( + return xla::SelectAndScatterWithGeneralPadding( operand.op(), select.computation(), window_dimensions, window_strides, padding, source.op(), init_value.op(), scatter.computation()); } @@ -455,22 +451,22 @@ LocalOp LocalComputationBuilder::Tuple( xla_ops.push_back(op.op()); } - return builder_.Tuple(xla_ops); + return xla::Tuple(&builder_, xla_ops); } LocalOp LocalComputationBuilder::GetTupleElement(const LocalOp& tuple_data, int64 index) { - return builder_.GetTupleElement(tuple_data.op(), index); + return xla::GetTupleElement(tuple_data.op(), index); } LocalOp LocalComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { - return builder_.Dot(lhs.op(), rhs.op()); + return xla::Dot(lhs.op(), rhs.op()); } LocalOp LocalComputationBuilder::DotGeneral( const LocalOp& lhs, const LocalOp& rhs, const DotDimensionNumbers& dimension_numbers) { - return builder_.DotGeneral(lhs.op(), rhs.op(), dimension_numbers); + return xla::DotGeneral(lhs.op(), rhs.op(), dimension_numbers); } LocalOp LocalComputationBuilder::ConvGeneralDilated( @@ -480,14 +476,13 @@ LocalOp LocalComputationBuilder::ConvGeneralDilated( tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers) { - return builder_.ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, - padding, lhs_dilation, rhs_dilation, - dimension_numbers); + return xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding, + lhs_dilation, rhs_dilation, dimension_numbers); } LocalOp LocalComputationBuilder::ConvertElementType( const LocalOp& operand, PrimitiveType new_element_type) { - return builder_.ConvertElementType(operand.op(), new_element_type); + return xla::ConvertElementType(operand.op(), new_element_type); } LocalOp LocalComputationBuilder::Call( @@ -498,46 +493,39 @@ LocalOp LocalComputationBuilder::Call( for (const auto& op : operands) { xla_ops.push_back(op.op()); } - return builder_.Call(local_computation.computation(), xla_ops); + return xla::Call(&builder_, local_computation.computation(), xla_ops); } LocalOp LocalComputationBuilder::Transpose( const LocalOp& operand, tensorflow::gtl::ArraySlice permutation) { - return builder_.Transpose(operand.op(), permutation); + return xla::Transpose(operand.op(), permutation); } LocalOp LocalComputationBuilder::Rev( const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { - return builder_.Rev(operand.op(), dimensions); + return xla::Rev(operand.op(), dimensions); } LocalOp LocalComputationBuilder::Map( tensorflow::gtl::ArraySlice operands, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands) { + tensorflow::gtl::ArraySlice dimensions) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { xla_ops.push_back(op.op()); } - std::vector static_xla_ops; - static_xla_ops.reserve(static_operands.size()); - for (const auto& op : static_operands) { - static_xla_ops.push_back(op.op()); - } - - return builder_.Map(xla_ops, local_computation.computation(), dimensions, - static_xla_ops); + return xla::Map(&builder_, xla_ops, local_computation.computation(), + dimensions); } LocalOp LocalComputationBuilder::Reduce( const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice dimensions_to_reduce) { - return builder_.Reduce(operand.op(), init_value.op(), - local_computation.computation(), dimensions_to_reduce); + return xla::Reduce(operand.op(), init_value.op(), + local_computation.computation(), dimensions_to_reduce); } LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( @@ -546,7 +534,7 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding) { - return builder_.ReduceWindowWithGeneralPadding( + return xla::ReduceWindowWithGeneralPadding( operand.op(), init_value.op(), local_computation.computation(), window_dimensions, window_strides, padding); } @@ -554,27 +542,27 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu, const LocalOp& sigma, const Shape& shape) { - return builder_.RngNormal(mu.op(), sigma.op(), shape); + return xla::RngNormal(mu.op(), sigma.op(), shape); } LocalOp LocalComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape) { - return builder_.RngUniform(a.op(), b.op(), shape); + return xla::RngUniform(a.op(), b.op(), shape); } LocalOp LocalComputationBuilder::While(const LocalComputation& condition, const LocalComputation& body, const LocalOp& init) { - return builder_.While(condition.computation(), body.computation(), init.op()); + return xla::While(condition.computation(), body.computation(), init.op()); } LocalOp LocalComputationBuilder::Conditional( const LocalOp& predicate, const LocalOp& true_operand, const LocalComputation& true_computation, const LocalOp& false_operand, const LocalComputation& false_computation) { - return builder_.Conditional( - predicate.op(), true_operand.op(), true_computation.computation(), - false_operand.op(), false_computation.computation()); + return xla::Conditional(predicate.op(), true_operand.op(), + true_computation.computation(), false_operand.op(), + false_computation.computation()); } StatusOr LocalComputationBuilder::IsConstant(const LocalOp& operand) { @@ -590,7 +578,7 @@ StatusOr LocalComputationBuilder::BuildConstantSubGraph( #define _FORWARD(method_name, return_sig, args_sig, args) \ return_sig LocalComputationBuilder::method_name args_sig { \ - return builder_.method_name args; \ + return xla::method_name args; \ } #define _FORWARD_UNOP(method_name) \ @@ -624,6 +612,7 @@ _FORWARD_BINOP(Max) _FORWARD_BINOP(Min) _FORWARD_BINOP(And) _FORWARD_BINOP(Or) +_FORWARD_BINOP(Xor) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 0da3964676e9c6729229686f38bb05c8b2427bff..e920f8aecd6cfc6fd4c965b1cc9eceb36b2d7371 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -187,7 +187,7 @@ class LocalComputationBuilder { LocalOp Parameter(int64 parameter_number, const Shape& shape, const string& name); - std::unique_ptr GetShape(const LocalOp& operand); + StatusOr GetShape(const LocalOp& operand); // Returns the shape of the current return value for the computation. StatusOr GetReturnValueShape(); @@ -270,8 +270,7 @@ class LocalComputationBuilder { LocalOp Map(tensorflow::gtl::ArraySlice operands, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands); + tensorflow::gtl::ArraySlice dimensions); LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, @@ -333,6 +332,7 @@ class LocalComputationBuilder { _FORWARD_BINOP(Min) _FORWARD_BINOP(And) _FORWARD_BINOP(Or) + _FORWARD_BINOP(Xor) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 477df6fde25d0db760e08df9d335bd12e31ccb55..76e9e637cd45509ec443be092fd9934db1a9653f 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -988,6 +988,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Min; %unignore xla::swig::LocalComputationBuilder::And; %unignore xla::swig::LocalComputationBuilder::Or; +%unignore xla::swig::LocalComputationBuilder::Xor; %unignore xla::swig::LocalComputationBuilder::Not; %unignore xla::swig::LocalComputationBuilder::Abs; %unignore xla::swig::LocalComputationBuilder::Exp; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index c025127c3cf1871d4def1297ed36c046cae61d4b..abb97d0c6fae515b8f1c11c7df48299f05fc9fad 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -123,6 +123,7 @@ _BINARY_OPS = [ 'Min', 'And', 'Or', + 'Xor', 'Pow', ] @@ -257,9 +258,12 @@ class Shape(object): self._dimensions == other._dimensions and self._minor_to_major == other._minor_to_major) + def __ne__(self, other): + return not self == other + def __repr__(self): return ('xla_client.Shape(_dtype={!r}, _dimensions={!r}, ' - '_is_tuple={!r}), _minor_to_major={!r}').format( + '_is_tuple={!r}, _minor_to_major={!r})').format( self._dtype, self._dimensions, self._is_tuple, self._minor_to_major) @@ -905,20 +909,19 @@ class ComputationBuilder(object): """ return self._client.Call(computation_to_apply.c_local_computation, operands) - def Map(self, operands, computation_to_apply, dimensions, static_operands=()): + def Map(self, operands, computation_to_apply, dimensions): """Enqueues a map operation onto the computation. Args: operands: an iterable of LocalOp. computation_to_apply: a Computation object. dimensions: dimensions over which to apply map the function. - static_operands: auxiliary arguments passed to the applied computation. Returns: A LocalOp representing the added Map op. """ return self._client.Map(operands, computation_to_apply.c_local_computation, - dimensions, static_operands) + dimensions) def Reduce(self, operand, init_value, computation_to_apply, dimensions): """Enqueues a reduction operation onto the computation. diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 71e1d60a4e23dbfef333223c396e109533da9365..0564ddcb85ee3952f82649687e79a864999baf2c 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -157,6 +157,13 @@ class ComputationsWithConstantsTest(LocalComputationTest): c.Constant(NumpyArrayBool([True, True, False, False]))) self._ExecuteAndCompareExact(c, expected=[True, True, True, False]) + def testBooleanXor(self): + c = self._NewComputation() + c.Xor( + c.Constant(NumpyArrayBool([True, False, True, False])), + c.Constant(NumpyArrayBool([True, True, False, False]))) + self._ExecuteAndCompareExact(c, expected=[False, True, True, False]) + def testSum2DF32(self): c = self._NewComputation() c.Add( @@ -1168,14 +1175,6 @@ class EmbeddedComputationsTest(LocalComputationTest): self._CreateBinaryDivF64Computation(), [0]) self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0]) - def DISABLED_testMapWithStaticOperands(self): - c = self._NewComputation() - factor = c.ConstantF32Scalar(3.0) - c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], - self._CreateMulF32ByParamComputation(), [0], - static_operands=[factor]) - self._ExecuteAndCompareClose(c, expected=[3.0, 6.0, 9.0, 12.0]) - def testSelectAndScatterF32(self): c = self._NewComputation() c.SelectAndScatter(c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])), diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index d7dd9786a2bbde2d18ae81a9a9d4cc4b2cc38411..f8414468bd9e0a9faf0072c47d94d12ab11b908d 100644 --- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -85,13 +85,13 @@ TEST_F(GRPCClientTestBase, ItsAlive) { TEST_F(GRPCClientTestBase, AxpyTenValues) { XlaBuilder builder("axpy_10"); - auto alpha = builder.ConstantR0(3.1415926535); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto y = builder.ConstantR1( - {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0}); - auto ax = builder.Mul(alpha, x); - auto axpy = builder.Add(ax, y); + auto alpha = ConstantR0(&builder, 3.1415926535); + auto x = ConstantR1( + &builder, {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); + auto y = ConstantR1( + &builder, {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0}); + auto ax = Mul(alpha, x); + Add(ax, y); std::vector expected = { 1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796, diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 8a1d1bf73d51d81f6a9cf353c0bd0591231f5225..fe99f700d23dbab799ba011b705c59d6ef7a2e52 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -32,6 +32,7 @@ tf_proto_library_py( name = "hlo_proto", # bzl adds a _py suffix only to the OSS target. srcs = ["hlo.proto"], visibility = ["//visibility:public"], + deps = ["//tensorflow/compiler/xla:xla_data_proto_py"], ) xla_proto_library( @@ -1951,6 +1952,7 @@ cc_library( hdrs = ["tuple_points_to_analysis.h"], deps = [ ":hlo", + ":hlo_dataflow_analysis", ":logical_buffer", ":logical_buffer_analysis", "//tensorflow/compiler/xla:shape_tree", @@ -2093,6 +2095,7 @@ cc_library( hdrs = ["hlo_verifier.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_pass", ":shape_inference", "//tensorflow/compiler/xla:status_macros", @@ -2399,6 +2402,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 1fc8fb9b6994db78fe3aa06e1ea790decfce7b97..1ddeb27e4041df22bd3d0ec200bcddbd09937e01 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -50,20 +50,15 @@ namespace { namespace m = match; -// Returns whether operand is a literal with the given value. -bool IsLiteralWithValue(const HloInstruction* operand, int8 value) { - return operand->opcode() == HloOpcode::kConstant && - operand->literal().IsAll(value); -} - bool IsAll(const HloInstruction* op, int8 value) { - if (IsLiteralWithValue(op, value)) { - return true; - } - if (op->opcode() == HloOpcode::kBroadcast && IsAll(op->operand(0), value)) { - return true; + switch (op->opcode()) { + case HloOpcode::kBroadcast: + return IsAll(op->operand(0), value); + case HloOpcode::kConstant: + return op->literal().IsAll(value); + default: + return false; } - return false; } // Returns whether the given transpose produces a result which is bit-wise @@ -75,21 +70,22 @@ bool TransposeIsBitcast(const HloInstruction* transpose) { transpose->dimensions()); } -// Returns true if the given reshape produces a result which is bit-wise +// Returns true if the given reshape/copy produces a result which is bit-wise // identical to its operand and thus may be replaced with a bitcast. // // This function is conservative -- even if this function returns false, the // reshape may still be a bitcast. For example, a reshape from [28x28] to [784]. -bool ReshapeIsBitcast( - const HloInstruction* reshape, +bool ReshapeOrCopyIsBitcast( + const HloInstruction* instr, const AlgebraicSimplifier::ValidBitcastCallback& valid_bitcast_callback) { - CHECK_EQ(HloOpcode::kReshape, reshape->opcode()); + CHECK(HloOpcode::kReshape == instr->opcode() || + HloOpcode::kCopy == instr->opcode()); - const HloInstruction* operand = reshape->operand(0); + const HloInstruction* operand = instr->operand(0); // Can't insert bitcasts if the compiler used a memory layout which isn't // compatible. - return ShapeUtil::ReshapeIsBitcast(operand->shape(), reshape->shape()) && - valid_bitcast_callback(operand->shape(), reshape->shape()); + return ShapeUtil::ReshapeIsBitcast(operand->shape(), instr->shape()) && + valid_bitcast_callback(operand->shape(), instr->shape()); } // AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain @@ -159,9 +155,6 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleMap(HloInstruction* map) override; - Status HandleMaximum(HloInstruction* maximum) override; - Status HandleMinimum(HloInstruction* minimum) override; - // Returns whether algebraic simplification has occurred. const bool changed() const { return changed_; } @@ -200,8 +193,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Helper method to perform and add reduction in a single dimension. HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { - HloInstruction* zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction* zero = + computation_->AddInstruction(HloInstruction::CreateConstant( + Literal::Zero(hlo->shape().element_type()).CloneToUnique())); HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); return computation_->AddInstruction(HloInstruction::CreateReduce( @@ -433,7 +427,15 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, op)); } // All copies can be eliminated (assuming layout constraints are satisified). - ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0)); + if (ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0))) { + return Status::OK(); + } + + if (is_layout_sensitive_ && + ReshapeOrCopyIsBitcast(copy, valid_bitcast_callback_)) { + ReplaceWithBitcast(copy); + } + return Status::OK(); } @@ -528,6 +530,10 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { constant, BuildTupleConstant(computation_, constant->literal())); } + if (constant->shape().element_type() == TOKEN) { + return Status::OK(); + } + // If a literal is all the same element replace it with a scalar broadcast. if (ShapeUtil::ElementsIn(constant->shape()) > 1 && constant->literal().IsAllFirst()) { @@ -563,6 +569,14 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { return Status::OK(); } +namespace { +template +Status InvertConstant(const HloInstruction& constant, Literal* result) { + return result->Populate([&](tensorflow::gtl::ArraySlice indices) { + return T{1.0} / constant.literal().Get(indices); + }); +} +} // namespace Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { Shape* shape; @@ -624,14 +638,31 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { // (Backends can do this transformation, but generally only if the constant is // a scalar.) if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) { - HloInstruction* one = - computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::One(a->shape().element_type()).CloneToUnique())); - HloInstruction* inverse = computation_->AddInstruction( - HloInstruction::CreateBinary(b->shape(), HloOpcode::kDivide, one, b)); - return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary(divide->shape(), - HloOpcode::kMultiply, a, inverse)); + Literal new_literal(b->shape()); + switch (b->shape().element_type()) { + case F16: + TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + break; + case F32: + TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + break; + case BF16: + TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + break; + case F64: + TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + break; + case C64: + TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + break; + default: + return Status::OK(); + } + auto inverse = computation_->AddInstruction( + HloInstruction::CreateConstant((new_literal.CloneToUnique()))); + TF_ASSIGN_OR_RETURN(auto new_divide, + MakeBinaryHlo(HloOpcode::kMultiply, a, inverse)); + return ReplaceInstruction(divide, new_divide); } // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C) @@ -651,18 +682,18 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) { TF_ASSIGN_OR_RETURN(auto b_times_c, MakeBinaryHlo(HloOpcode::kMultiply, b, c)); - return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary(divide->shape(), - HloOpcode::kDivide, a, b_times_c)); + TF_ASSIGN_OR_RETURN(auto new_divide, + MakeBinaryHlo(HloOpcode::kDivide, a, b_times_c)); + return ReplaceInstruction(divide, new_divide); } // A / (B / C) => (A*C) / B if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) { TF_ASSIGN_OR_RETURN(auto a_times_c, MakeBinaryHlo(HloOpcode::kMultiply, a, c)); - return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary(divide->shape(), - HloOpcode::kDivide, a_times_c, b)); + TF_ASSIGN_OR_RETURN(auto new_divide, + MakeBinaryHlo(HloOpcode::kDivide, a_times_c, b)); + return ReplaceInstruction(divide, new_divide); } return Status::OK(); @@ -1221,9 +1252,10 @@ bool OutputIsPermutationOfOperandElements(HloInstruction* instruction, switch (instruction->opcode()) { case HloOpcode::kReshape: case HloOpcode::kReverse: - case HloOpcode::kSort: case HloOpcode::kTranspose: return true; + case HloOpcode::kSort: + return (!ShapeUtil::IsTuple(instruction->shape())); default: return false; } @@ -1672,7 +1704,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { // Make this a bitcast if possible. if (is_layout_sensitive_ && - ReshapeIsBitcast(reshape, valid_bitcast_callback_)) { + ReshapeOrCopyIsBitcast(reshape, valid_bitcast_callback_)) { ReplaceWithBitcast(reshape); return Status::OK(); } @@ -2065,10 +2097,9 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( convolution, HloInstruction::CreateBroadcast( convolution->shape(), - computation_->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::MakeShape(convolution->shape().element_type(), {}), - computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))))), + computation_->AddInstruction(HloInstruction::CreateConstant( + Literal::Zero(convolution->shape().element_type()) + .CloneToUnique())), {})); } const auto& window = convolution->window(); @@ -2240,68 +2271,6 @@ Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) { return ReplaceWithNewInstruction(map, std::move(clone)); } -Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) { - // Match the following tree: - // min_operand operand - // \ / - // max_operand min - // \ / - // max - // where max_operand and min_operand are scalar constants. - { - HloInstruction* min; - HloInstruction* max_operand; - HloInstruction* min_operand; - HloInstruction* operand; - - if (hlo_query::MatchBinaryInstructionOperandOpcode( - HloOpcode::kMinimum, maximum, - /*matching_operand=*/&min, - /*other_operand=*/&max_operand) && - hlo_query::MatchBinaryInstructionOperand( - hlo_query::IsScalarConstant, min, - /*matching_operand=*/&min_operand, - /*other_operand=*/&operand) && - TransformToClampIfSameShape(maximum, min, min_operand, operand, maximum, - max_operand)) { - return Status::OK(); - } - } - - return Status::OK(); -} - -Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) { - // Match the following tree: - // max_operand operand - // \ / - // min_operand max - // \ / - // min - // where max_operand and min_operand are scalar constants. - { - HloInstruction* max; - HloInstruction* max_operand; - HloInstruction* min_operand; - HloInstruction* operand; - - if (hlo_query::MatchBinaryInstructionOperandOpcode( - HloOpcode::kMaximum, minimum, - /*matching_operand=*/&max, - /*other_operand=*/&min_operand) && - hlo_query::MatchBinaryInstructionOperand( - hlo_query::IsScalarConstant, max, - /*matching_operand=*/&max_operand, - /*other_operand=*/&operand) && - TransformToClampIfSameShape(minimum, minimum, min_operand, operand, max, - max_operand)) { - return Status::OK(); - } - } - - return Status::OK(); -} - StatusOr AlgebraicSimplifier::Run(HloModule* module) { XLA_VLOG_LINES(2, "AlgebraicSimplifier::Run(), before:\n" + module->ToString()); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 2605b0488cb7c6850746df94c4ab05d6b5d35de5..b733f6f59eb028b2dff921722c462441251772fe 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -201,8 +201,11 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* zero = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - builder.AddInstruction( - HloInstruction::CreateMap(r2f32, {param0, zero}, add_computation)); + builder.AddInstruction(HloInstruction::CreateMap( + r2f32, + {param0, builder.AddInstruction( + HloInstruction::CreateBroadcast(r2f32, zero, {}))}, + add_computation)); auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); @@ -211,7 +214,7 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, zero)); + EXPECT_THAT(root, op::Add(param0, op::Broadcast(zero))); } TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { @@ -367,17 +370,16 @@ TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { // Test that (A/B)/(C/D) is simplified to (A*D)/(B*C). TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); Shape r2f32 = ShapeUtil::MakeShape(F32, {42, 123}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, r2f32, "param1")); HloInstruction* param2 = builder.AddInstruction( HloInstruction::CreateParameter(2, r2f32, "param2")); HloInstruction* param3 = builder.AddInstruction( - HloInstruction::CreateParameter(3, r0f32, "param3")); + HloInstruction::CreateParameter(3, r2f32, "param3")); HloInstruction* div0 = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, param1)); HloInstruction* div1 = builder.AddInstruction( @@ -398,8 +400,6 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { EXPECT_THAT( computation->root_instruction(), op::Divide(op::Multiply(param0, param3), op::Multiply(param1, param2))); - EXPECT_TRUE( - ShapeUtil::Compatible(computation->root_instruction()->shape(), r2f32)); } // Test that A/exp(B) is simplified to A*exp(-B). @@ -459,7 +459,6 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) { // Test that broadcasting is done on the right step when simplifying A/pow(B,C) // to A*pow(B,-C). TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -467,7 +466,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, r1f32, "param1")); HloInstruction* param2 = builder.AddInstruction( - HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction::CreateParameter(2, r1f32, "param2")); HloInstruction* power = builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param1, param2)); builder.AddInstruction( @@ -484,14 +483,9 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { ASSERT_THAT(computation->root_instruction(), op::Multiply(param0, op::Power(param1, op::Negate(param2)))); - - const HloInstruction* negate = - computation->root_instruction()->operand(1)->operand(1); - const Shape& negate_shape = negate->shape(); - EXPECT_EQ(0, negate_shape.dimensions_size()); } -// A / Const => A * (1 / Const) +// A / Const => A * InvertedConst TEST_F(AlgebraicSimplifierTest, DivideByConstant) { Shape r1f32 = ShapeUtil::MakeShape(F32, {3}); HloComputation::Builder builder(TestName()); @@ -510,20 +504,19 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Divide(op::Constant(), constant))); + op::Multiply(param0, op::Constant())); } // pow(pow(A, X), Y) => pow(A, X*Y) TEST_F(AlgebraicSimplifierTest, PowerOfPower) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); HloComputation::Builder builder(TestName()); HloInstruction* base = builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* exp1 = builder.AddInstruction( - HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction::CreateParameter(1, r1f32, "param1")); HloInstruction* exp2 = builder.AddInstruction( - HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction::CreateParameter(2, r1f32, "param2")); HloInstruction* inner_power = builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, base, exp1)); builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, @@ -540,15 +533,14 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) { // Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex // numbers. TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) { - Shape r0c64 = ShapeUtil::MakeShape(C64, {}); Shape r1c64 = ShapeUtil::MakeShape(C64, {7}); HloComputation::Builder builder(TestName()); HloInstruction* base = builder.AddInstruction( HloInstruction::CreateParameter(0, r1c64, "param0")); HloInstruction* exp1 = builder.AddInstruction( - HloInstruction::CreateParameter(1, r0c64, "param1")); + HloInstruction::CreateParameter(1, r1c64, "param1")); HloInstruction* exp2 = builder.AddInstruction( - HloInstruction::CreateParameter(2, r0c64, "param2")); + HloInstruction::CreateParameter(2, r1c64, "param2")); HloInstruction* inner_power = builder.AddInstruction( HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, base, exp1)); builder.AddInstruction(HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, @@ -1159,6 +1151,33 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) { EXPECT_THAT(computation->root_instruction(), param0); } +TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) { + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 14, 14, 64}), "param")); + *param->mutable_shape()->mutable_layout() = + LayoutUtil::MakeLayout({0, 1, 2, 3}); + HloInstruction* copy = builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShape(F32, {1, 14, 14, 64}), HloOpcode::kCopy, param)); + *copy->mutable_shape()->mutable_layout() = + LayoutUtil::MakeLayout({1, 2, 0, 3}); + auto computation = module().AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), op::Copy(param)); + + AlgebraicSimplifier simplifier1(/*is_layout_sensitive=*/true, + non_bitcasting_callback()); + ASSERT_FALSE(simplifier1.Run(&module()).ValueOrDie()); + // Verify that the copy is not replaced. + EXPECT_THAT(computation->root_instruction(), op::Copy(param)); + + AlgebraicSimplifier simplifier2(/*is_layout_sensitive=*/true, + bitcasting_callback()); + ASSERT_TRUE(simplifier2.Run(&module()).ValueOrDie()); + // Verify that the copy is replaced. + EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); +} + // Test that unary concatenates are removed. TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { Shape r1f32 = ShapeUtil::MakeShape(F32, {100}); @@ -1389,33 +1408,6 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape)); } -// Regression test for a bug in the reshape sinking transformation, where -// moving a reshape to a scalar led to a crash. -TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { - HloComputation::Builder builder(TestName()); - HloInstruction* param = - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 1}), "param")); - HloInstruction* reshape = builder.AddInstruction( - HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {}), param)); - HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1., 2., 3.}))); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {3}), HloOpcode::kMaximum, reshape, zero)); - auto computation = module().AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Maximum(op::Reshape(param), zero)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); - - simplifier.Run(&module()).ValueOrDie(); - - EXPECT_THAT(computation->root_instruction(), - op::Maximum(op::Reshape(param), zero)); -} - // Regression test for a bug where if we failed to sink a reshape, we'd set the // 'changed' bit in AlgebraicSimplifier to false. TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { @@ -2076,160 +2068,6 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { EXPECT_EQ("NO_CHANGE", build_and_simplify()); } -// Test that max(min(A, x), y) is transformed to clamp(y, A, x) -TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - HloComputation::Builder builder(TestName()); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "param0")); - HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); - HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary( - r0f32, HloOpcode::kMinimum, param0, min_value)); - builder.AddInstruction( - HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Maximum(op::Minimum(param0, min_value), max_value)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Clamp(max_value, param0, min_value)); -} - -// Test that min(max(A, x), y) is transformed to clamp(x, A, y) for scalar -// values. -TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - HloComputation::Builder builder(TestName()); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "param0")); - HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); - HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( - r0f32, HloOpcode::kMaximum, param0, max_value)); - builder.AddInstruction( - HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Maximum(param0, max_value), min_value)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Clamp(max_value, param0, min_value)); -} - -// Test that min(max(A, x), y) is transformed to clamp(x, A, y) for -// broadcasted scalar values. -TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - Shape r1f32 = ShapeUtil::MakeShape(F32, {100}); - HloComputation::Builder builder(TestName()); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r1f32, "param0")); - HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); - HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( - r1f32, HloOpcode::kMaximum, param0, max_value)); - builder.AddInstruction( - HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Maximum(param0, max_value), min_value)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Clamp(max_value, param0, min_value)); -} - -// Test that min(max(A, non-constant1), non-constant2) is not canonicalized to -// clamp(non-constant1, A, non-constant2) -TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - HloComputation::Builder builder(TestName()); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "param0")); - HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateParameter(1, r0f32, "param1")); - HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateParameter(2, r0f32, "param2")); - HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( - r0f32, HloOpcode::kMaximum, param0, max_value)); - builder.AddInstruction( - HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Maximum(param0, max_value), min_value)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Maximum(param0, max_value), min_value)); -} - -// Test that min(f(max(A, constant1)), constant2) is not transformed to -// clamp(constant1, A, constant2) -TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - HloComputation::Builder builder(TestName()); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "param0")); - HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); - HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( - r0f32, HloOpcode::kMaximum, param0, max_value)); - HloInstruction* fmax = builder.AddInstruction( - HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, max, max_value)); - builder.AddInstruction(HloInstruction::CreateBinary( - r0f32, HloOpcode::kMinimum, fmax, min_value)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), - min_value)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), - min_value)); -} - // Test that slice(broadcast(/*scalar value*/)) simplifies to a single // broadcast. TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 8f1d2f0804960b04dbff4c990c356589a609ce8d..ff6d5027efba813042af65a0e50e172cc0a99ff8 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -85,9 +85,9 @@ void BFloat16Propagation::RevertIfFusionInternalBF16Changes( auto root_changes_it = changes_to_bf16_.find(root); if (root_changes_it != changes_to_bf16_.end()) { - for (const auto& index : root_changes_it->second) { + for (const auto& entry : root_changes_it->second) { for (const HloValue* value : - dataflow_->GetValueSet(root, index).values()) { + dataflow_->GetValueSet(root, entry.second).values()) { changed_root_buffers.insert(value); } } @@ -204,6 +204,12 @@ void BFloat16Propagation::DetermineWhileComputationsPrecision( bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, const ShapeIndex& index) const { + // If the subshape isn't floating point then none of the users will be BF16. + const Shape& subshape = ShapeUtil::GetSubshape(hlo.shape(), index); + if (subshape.element_type() != BF16 && subshape.element_type() != F32) { + return false; + } + auto& value_set = dataflow_->GetValueSet(&hlo, index); for (const HloValue* value : value_set.values()) { if (ContainsKey(values_that_must_be_kept_as_f32_, value)) { @@ -257,23 +263,34 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, // If the op propagates precision and it outputs a BF16, then it's OK to // supply BF16 also as the input. In the backward pass, the users shapes // should have already been processed. - PrimitiveType user_output_type = PRIMITIVE_TYPE_INVALID; - if (use.instruction->opcode() == HloOpcode::kTuple || - (use.instruction->opcode() == HloOpcode::kCrossReplicaSum && - ShapeUtil::IsTuple(use.instruction->shape()))) { - ShapeIndex use_output_index{use.operand_number}; - for (int64 i : use.operand_index) { - use_output_index.push_back(i); - } - user_output_type = - OutputTypeAfterChange(use.instruction, use_output_index); - } else { - user_output_type = OutputTypeAfterChange(use.instruction, {}); - } if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision( - *use.instruction, use.operand_number) && - user_output_type == BF16) { - continue; + *use.instruction, use.operand_number)) { + if (use.instruction->opcode() == HloOpcode::kTuple || + (use.instruction->opcode() == HloOpcode::kCrossReplicaSum && + ShapeUtil::IsTuple(use.instruction->shape()))) { + ShapeIndex use_output_index{use.operand_number}; + for (int64 i : use.operand_index) { + use_output_index.push_back(i); + } + if (OutputTypeAfterChange(use.instruction, use_output_index) == + BF16) { + continue; + } + } else if (use.instruction->opcode() == HloOpcode::kGetTupleElement) { + ShapeIndex use_output_index; + for (int64 i = 1; i < use.operand_index.size(); ++i) { + use_output_index.push_back(use.operand_index[i]); + } + if (OutputTypeAfterChange(use.instruction, use_output_index) == + BF16) { + continue; + } + } else { + if (OutputTypeAfterChange(use.instruction, use.operand_index) == + BF16) { + continue; + } + } } return false; } @@ -368,6 +385,7 @@ bool BFloat16Propagation::InstructionIsCandidateForBF16Output( if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) && hlo->opcode() != HloOpcode::kTuple && hlo->opcode() != HloOpcode::kGetTupleElement && + hlo->opcode() != HloOpcode::kDomain && hlo->shape().element_type() != BF16) { for (int64 i = 0; i < hlo->operand_count(); ++i) { if (!bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo, @@ -559,7 +577,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( HloModule* module) { - std::list computations_topological_order = + const auto& computations_topological_order = module->MakeComputationPostOrder(); tensorflow::gtl::FlatSet resolved; for (auto comp_it = computations_topological_order.rbegin(); @@ -742,7 +760,7 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module)); - std::list computations_topological_order = + const auto& computations_topological_order = module->MakeComputationPostOrder(); // The first step is a forward pass (parameters to root), where we determine // the potential candidate instructions to use bfloat16 in the outputs that @@ -784,9 +802,8 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { // Apply the changes in changes_to_bf16_. for (auto& change : changes_to_bf16_) { - auto shape = change.first->mutable_shape(); - for (const auto& index : change.second) { - auto subshape = ShapeUtil::GetMutableSubshape(shape, index); + for (const auto& entry : change.second) { + auto subshape = entry.first; CHECK_EQ(subshape->element_type(), F32); subshape->set_element_type(BF16); changed_ = true; @@ -815,8 +832,8 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { PrimitiveType BFloat16Propagation::OutputTypeAfterChange( HloInstruction* hlo, const ShapeIndex& index) const { - PrimitiveType type_on_hlo = - ShapeUtil::GetSubshape(hlo->shape(), index).element_type(); + Shape* subshape = ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index); + const PrimitiveType type_on_hlo = subshape->element_type(); if (type_on_hlo != F32) { return type_on_hlo; } @@ -824,7 +841,7 @@ PrimitiveType BFloat16Propagation::OutputTypeAfterChange( if (it == changes_to_bf16_.end()) { return type_on_hlo; } - return ContainsKey(it->second, index) ? BF16 : F32; + return ContainsKey(it->second, subshape) ? BF16 : F32; } PrimitiveType BFloat16Propagation::ValueTypeAfterChange( @@ -838,14 +855,16 @@ void BFloat16Propagation::AddToOrRemoveFromBF16ChangeSet( HloInstruction* hlo, const ShapeIndex& index, PrimitiveType target_type) { if (target_type == BF16) { auto& entry = changes_to_bf16_[hlo]; - entry.insert(index); + entry.emplace(ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index), + index); } else { CHECK_EQ(target_type, F32); auto it = changes_to_bf16_.find(hlo); if (it == changes_to_bf16_.end()) { return; } - it->second.erase(index); + it->second.erase( + ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index)); } } diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h index de0355ddfca127753f90d1899b424a8e77c9b291..02b8cad089dd8465b7af5c1014e37b77ded6949d 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.h +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -194,17 +194,11 @@ class BFloat16Propagation : public HloPassInterface { // are subject to further adjustment, then finally applied to the HLOs. This // avoids setting changed_ to true but all changes are reverted during // adjustment. - struct IndexHasher { - int64 operator()(const ShapeIndex& index) const { - int64 hash = 0; - for (int64 i : index) { - hash = tensorflow::Hash64Combine(hash, std::hash()(i)); - } - return hash; - } - }; + // + // For each HloInstruction, changes_to_bf16_ stores the affected buffers in + // the output as a map from in-place pointers to subshapes to shape indices. tensorflow::gtl::FlatMap> + tensorflow::gtl::FlatMap> changes_to_bf16_; // Whether the last processed HLO module has been changed by this pass. diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 5e1499ee6b6ef397f95f7ed29e808d530777bd07..560910cc5ffbf74737b6f025f7da2928c9cd621b 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -150,11 +150,11 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant); EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant); EXPECT_TRUE(LiteralTestUtil::Equal( - dot->operand(0)->literal(), - *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_a)))); + *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_a)), + dot->operand(0)->literal())); EXPECT_TRUE(LiteralTestUtil::Equal( - dot->operand(1)->literal(), - *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_b)))); + *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_b)), + dot->operand(1)->literal())); } // Tests that BF16 can be propagated through nested tuples. @@ -742,4 +742,89 @@ TEST_F(BFloat16PropagationTest, NoopConversionRemoved) { EXPECT_EQ(add1->shape().element_type(), BF16); } +TEST_F(BFloat16PropagationTest, TupleDomain) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* a_trans = + builder.AddInstruction(HloInstruction::CreateTranspose(shape, a, {0, 1})); + HloInstruction* b_trans = + builder.AddInstruction(HloInstruction::CreateTranspose(shape, b, {0, 1})); + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({a_trans, b_trans})); + HloInstruction* domain = builder.AddInstruction( + HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr)); + HloInstruction* a_gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, domain, 0)); + HloInstruction* b_gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, domain, 1)); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_gte, b_gte)); + HloInstruction* root = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_EQ(computation->root_instruction(), root); + + // test BF16 propagated through domain + EXPECT_EQ(ShapeUtil::GetTupleElementShape(domain->shape(), 0).element_type(), + BF16); + EXPECT_EQ(ShapeUtil::GetTupleElementShape(domain->shape(), 1).element_type(), + BF16); + + EXPECT_TRUE(OutputsBF16(a_trans)); + EXPECT_TRUE(OutputsBF16(b_trans)); + EXPECT_TRUE(OutputsBF16(a_gte)); + EXPECT_TRUE(OutputsBF16(b_gte)); + EXPECT_FALSE(OutputsBF16(a)); + EXPECT_FALSE(OutputsBF16(b)); +} + +// Tests that bf16 is not propagated through a domain in case its input cannot +// be propagated. In the case below the input of the domain is the parameter +// tuple which cannot be propagated, so the domain instruction is not propagated +// either. +TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + HloInstruction* domain = builder.AddInstruction( + HloInstruction::CreateDomain(param->shape(), param, nullptr, nullptr)); + HloInstruction* a_gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, domain, 0)); + HloInstruction* b_gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, domain, 1)); + HloInstruction* a_trans = builder.AddInstruction( + HloInstruction::CreateTranspose(shape, a_gte, {0, 1})); + HloInstruction* b_trans = builder.AddInstruction( + HloInstruction::CreateTranspose(shape, b_gte, {0, 1})); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_trans, b_trans)); + HloInstruction* root = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), root); + EXPECT_TRUE(OutputsBF16(a_trans)); + EXPECT_TRUE(OutputsBF16(b_trans)); + EXPECT_FALSE(OutputsBF16(a_gte)); + EXPECT_FALSE(OutputsBF16(b_gte)); + EXPECT_FALSE(OutputsBF16(domain)); + EXPECT_FALSE(OutputsBF16(param)); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc index 07b4b14b5ec1bdbc01345091105df69368b0b2fb..8595afca7e735528d9ef29a323696c0661fe971c 100644 --- a/tensorflow/compiler/xla/service/bfloat16_support.cc +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -25,6 +25,7 @@ bool BFloat16Support::SupportsBF16Operand(const HloInstruction& hlo, case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kCustomCall: + case HloOpcode::kDomain: case HloOpcode::kGetTupleElement: case HloOpcode::kTuple: case HloOpcode::kWhile: @@ -43,6 +44,7 @@ bool BFloat16Support::SupportsBF16Output(const HloInstruction& hlo) const { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kCustomCall: + case HloOpcode::kDomain: case HloOpcode::kGetTupleElement: case HloOpcode::kTuple: case HloOpcode::kWhile: @@ -81,6 +83,7 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( case HloOpcode::kConcatenate: case HloOpcode::kConvert: case HloOpcode::kCopy: + case HloOpcode::kDomain: case HloOpcode::kGetTupleElement: case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -92,6 +95,9 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( case HloOpcode::kTranspose: case HloOpcode::kTuple: return true; + case HloOpcode::kBitcast: + return hlo.shape().element_type() == + hlo.operand(0)->shape().element_type(); case HloOpcode::kDynamicSlice: return operand_index == 0; case HloOpcode::kDynamicUpdateSlice: diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index efa4696130ffeff669b0d674438a45c5a9d48ef2..28b5a5784ff7f5d0b7fd412d1c50f3025f11bb81 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -1874,11 +1874,15 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { auto module = CreateNewModule(); auto builder = HloComputation::Builder("entry"); - auto infeed = builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, "")); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto infeed = + builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, token, "")); + auto infeed_data = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(r0s32, infeed, 0)); auto cond0 = module->AddEmbeddedComputation(build_cond()); auto body0 = module->AddEmbeddedComputation(build_body()); auto while0 = builder.AddInstruction( - HloInstruction::CreateWhile(r0s32, cond0, body0, infeed)); + HloInstruction::CreateWhile(r0s32, cond0, body0, infeed_data)); auto cond1 = module->AddEmbeddedComputation(build_cond()); auto body1 = module->AddEmbeddedComputation(build_body()); @@ -1909,8 +1913,8 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { // computation, since the issue this test stresses depends on the order the // nodes are traversed during BufferAssignment. SequentialHloOrdering::HloModuleSequence sequence; - sequence[module->entry_computation()] = {infeed, while0, while1, zero, - add, while2, tuple}; + sequence[module->entry_computation()] = { + token, infeed, infeed_data, while0, while1, zero, add, while2, tuple}; TF_ASSERT_OK_AND_ASSIGN( auto assignment, BufferAssigner::Run( diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index 738d00881dd057fc13c115006c15e8f5b6d14a1d..924348c870b9ca3d86af560a0c8359af7220427e 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -148,14 +148,16 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { HloComputation::Builder outfeeder(TestName() + ".outfeeder"); auto value = outfeeder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + auto token = outfeeder.AddInstruction(HloInstruction::CreateAfterAll({})); outfeeder.AddInstruction( - HloInstruction::CreateOutfeed(f32, value, /*outfeed_config=*/"")); + HloInstruction::CreateOutfeed(f32, value, token, /*outfeed_config=*/"")); auto outfeed_computation = module->AddEmbeddedComputation(outfeeder.Build()); HloComputation::Builder outer(TestName() + ".outer"); outer.AddInstruction(HloInstruction::CreateCall( - ShapeUtil::MakeNil(), /*operands=*/{}, outfeed_computation)); + outfeed_computation->root_instruction()->shape(), /*operands=*/{}, + outfeed_computation)); module->AddEntryComputation(outer.Build()); diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 0dceed853dcbae211657f00433866cfe10c51fc7..6b3b9820f09803c8a04504e6c35c22de51abf04b 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -35,6 +35,13 @@ Compiler::ComputeBackendConfigs(const HloInstruction& hlo, return {}; } +std::unique_ptr +Compiler::ComputeDefaultBackendConfig(const HloInstruction& hlo, + se::StreamExecutor* executor) const { + CHECK(executor != nullptr); + return nullptr; +} + // Define a default version where metadata is not used. StatusOr>> Compiler::CompileAheadOfTime( diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index d1144f97bb2ab29d3d18f3b3f65a38af46e68dd1..99abb9bae32b35652e84cddc7c38dbd97ecb5006 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -179,6 +179,16 @@ class Compiler { ComputeBackendConfigs(const HloInstruction& hlo, se::StreamExecutor* executor) const; + // Returns the backend configuration that the backend chooses by default for + // the given HLO. Returns no configuration if the backend does not support + // configurations for the given HLO. + // + // The stream executor is passed in to provide information about the hardware + // that the backend configurations would be targeting. + virtual std::unique_ptr + ComputeDefaultBackendConfig(const HloInstruction& hlo, + se::StreamExecutor* executor) const; + // Compiles the HLO module for ahead-of-time execution. This is intended for // use in static compilation. virtual StatusOr>> diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index 868348547d9f5cbdc7576c7fc0697d72c3a3e557..c38719d50efaf7e1b95b5ed2cf3030f9bfdfe57f 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -144,8 +144,10 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { auto* conditional = computation->root_instruction(); ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); auto* false_computation = conditional->false_computation(); - false_computation->AddInstruction( - HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config")); + auto token = + false_computation->AddInstruction(HloInstruction::CreateAfterAll({})); + false_computation->AddInstruction(HloInstruction::CreateInfeed( + ShapeUtil::MakeShape(F32, {1}), token, "config")); EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); } diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index e0ce2e3555e7746d6df212123fe1f968937cceed..ab3d846403ef264cd732a9c01d524cd4ccf65c38 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -1093,12 +1093,13 @@ void MaybeDumpModule(const string& message, const HloModule& module) { } // namespace Status RemoveUnnecessaryCopies( - const HloOrdering& ordering, - const tensorflow::gtl::FlatSet& copies_to_exclude, HloModule* module) { + const HloOrdering& ordering, HloModule* module, + const HloDataflowAnalysis::FusionCanShareBufferFunction& + fusion_can_share_buffer) { MaybeDumpModule("after adding copies to resolve interference", *module); TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(module)); + HloAliasAnalysis::Run(module, fusion_can_share_buffer)); CopyRemover copy_remover(*alias_analysis, ordering, module); XLA_VLOG_LINES(3, copy_remover.ToString()); @@ -1106,7 +1107,6 @@ Status RemoveUnnecessaryCopies( for (HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kCopy && - !ContainsKey(copies_to_exclude, instruction->unique_id()) && instruction->CopyElisionAllowed()) { TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); } @@ -1150,16 +1150,13 @@ StatusOr CopyInsertion::Run(HloModule* module) { "Call graph must be flattened before copy insertion."); } - // Gather Ids of existing kCopy instructions in the module. We avoid removing - // these copies (except via DCE in TupleSimplifier) because they may have been - // added for reasons not considered by copy insertion (eg, layout assignment). - // Instruction id is used instead of HloInstruction* because the pointer - // values may be recycled. - tensorflow::gtl::FlatSet existing_copies; - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy) { - existing_copies.insert(instruction->unique_id()); + int64 num_existing_copies = 0; + if (VLOG_IS_ON(1)) { + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + ++num_existing_copies; + } } } } @@ -1179,8 +1176,7 @@ StatusOr CopyInsertion::Run(HloModule* module) { TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); DependencyHloOrdering ordering(module); - TF_RETURN_IF_ERROR( - RemoveUnnecessaryCopies(ordering, existing_copies, module)); + TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, module)); TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module)); @@ -1201,7 +1197,7 @@ StatusOr CopyInsertion::Run(HloModule* module) { } } } - VLOG(1) << "Num copies before copy-insertion: " << existing_copies.size(); + VLOG(1) << "Num copies before copy-insertion: " << num_existing_copies; VLOG(1) << "Num copies after copy-insertion: " << num_total_copies; } diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index 0d7b3c20f982cae21e5160fe5be20c85bf940ed7..e1973db928423cb4bbad00fe34329f731b23ea09 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -48,6 +47,15 @@ class CopyInsertion : public HloPassInterface { public: tensorflow::StringPiece name() const override { return "copy-insertion"; } + // fusion_can_share_buffer: backend specific function that decides whether a + // fusion can share buffer with its operand. + // + // TODO(b/80315712): Find a better way to tell whether a fusion can share + // buffer. + CopyInsertion(const HloDataflowAnalysis::FusionCanShareBufferFunction& + fusion_can_share_buffer = nullptr) + : fusion_can_share_buffer_(fusion_can_share_buffer) {} + // Run the pass on the given module. Returns whether the module was changed // (copies were inserted). StatusOr Run(HloModule* module) override; @@ -62,14 +70,20 @@ class CopyInsertion : public HloPassInterface { // // TODO(b/62548313): Remove this when buffer assignment is module-scoped. static StatusOr AddCopiesForBufferAssignment(HloModule* module); + + private: + // Backend specific function that decides whether a fusion can share buffer + // with its operand. + HloDataflowAnalysis::FusionCanShareBufferFunction fusion_can_share_buffer_; }; // Try to remove as many copies from the module as possible without introducing -// live range interference. Copy instructions (identified by their unique id) in -// the set copies_to_exclude are not considered for removal. +// live range interference. Only copy instructions that are eligible for +// copy elision are considered for removal. Status RemoveUnnecessaryCopies( - const HloOrdering& ordering, - const tensorflow::gtl::FlatSet& copies_to_exclude, HloModule* module); + const HloOrdering& ordering, HloModule* module, + const HloDataflowAnalysis::FusionCanShareBufferFunction& + fusion_can_share_buffer = nullptr); } // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index ed1a50f516ee23e0f034bf5c2ed15fac7a70c3cc..7ae8799b612449ecc3c45123e769aac817d12058 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -125,21 +125,27 @@ TEST_F(CopyInsertionTest, SingleConstant) { } TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { - // Verify that an kCopy instructions which exist in the pass before + // Verify that kCopy instructions which change layout and exist before // copy-insertion remain in the graph after copy-insertion. auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - HloInstruction* copy_1 = builder.AddInstruction(HloInstruction::CreateUnary( - constant->shape(), HloOpcode::kCopy, constant)); - HloInstruction* copy_2 = builder.AddInstruction(HloInstruction::CreateUnary( - constant->shape(), HloOpcode::kCopy, constant)); + HloInstruction* constant = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{0.f, 2.f}, {2.f, 4.f}}))); + auto minor_to_major = LayoutUtil::MinorToMajor(constant->shape()); + Layout reversed_layout = + LayoutUtil::MakeLayoutFromMajorToMinor(minor_to_major); + Shape copy_shape = constant->shape(); + *copy_shape.mutable_layout() = reversed_layout; + HloInstruction* copy_1 = builder.AddInstruction( + HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, constant)); + HloInstruction* copy_2 = builder.AddInstruction( + HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, constant)); HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( constant->shape(), HloOpcode::kAdd, copy_1, copy_2)); - HloInstruction* add_copy = builder.AddInstruction( - HloInstruction::CreateUnary(constant->shape(), HloOpcode::kCopy, add)); + builder.AddInstruction( + HloInstruction::CreateUnary(add->shape(), HloOpcode::kCopy, add)); module->AddEntryComputation(builder.Build()); @@ -147,12 +153,11 @@ TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { InsertCopies(module.get()); - EXPECT_EQ(CountCopies(*module), 3); + EXPECT_EQ(CountCopies(*module), 2); - EXPECT_EQ(module->entry_computation()->root_instruction(), add_copy); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - op::Copy(op::Add(op::Copy(op::Constant()), op::Copy(op::Constant())))); + EXPECT_EQ(module->entry_computation()->root_instruction(), add); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Add(op::Copy(op::Constant()), op::Copy(op::Constant()))); } TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { @@ -1605,8 +1610,8 @@ HloModule TokensShouldNotBeCopied %constant.1 = s32[] constant(1) %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 - %generate-token = token[] generate-token(token[] %get-tuple-element.2) - ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %generate-token) + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) } %Cond (param: (s32[], token[])) -> pred[] { @@ -1619,7 +1624,7 @@ HloModule TokensShouldNotBeCopied ENTRY %TokensShouldNotBeCopied () -> s32[] { %one = s32[] constant(1) %negative_one = s32[] negate(%one) - %init_token = token[] generate-token() + %init_token = token[] after-all() %init_tuple = (s32[], token[]) tuple(s32[] %negative_one, token[] %init_token) %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index b703be0f39e2032bc58479f0b957f9d8b01a77c3..2c3eb1ae367ffe1de93c6fc8f4efdc6d69964e10 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -53,29 +53,6 @@ cc_library( alwayslink = True, # Contains per-platform transfer manager registration ) -cc_library( - name = "external_constant_pool", - srcs = ["external_constant_pool.cc"], - hdrs = ["external_constant_pool.h"], - deps = [ - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "external_constant_pool_test", - srcs = ["external_constant_pool_test.cc"], - deps = [ - ":external_constant_pool", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:test", - ], -) - cc_library( name = "cpu_compiler", srcs = ["cpu_compiler.cc"], @@ -175,7 +152,6 @@ cc_library( ":cpu_runtime", ":custom_call_target_registry", ":disassembler", - ":external_constant_pool", ":orc_jit_memory_mapper", ":runtime_fp16", ":runtime_conv2d", @@ -256,7 +232,6 @@ cc_library( ":cpu_options", ":cpu_runtime", ":dot_op_emitter", - ":external_constant_pool", ":ir_emission_utils", ":ir_function", ":parallel_loop_emitter", @@ -273,6 +248,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index d039132535071661d047579587385210719fede3..55962ba70d213939ccb49cad3bdd75395cc4eaa5 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -269,6 +269,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, /*enable_dot_strength_reduction=*/false); + pass.AddPass(); // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. @@ -303,15 +304,19 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, ReducePrecisionInsertion::PassTiming::AFTER_FUSION); pipeline.AddPass( - module->mutable_device_entry_computation_layout(), - &target_machine_features); + module->mutable_entry_computation_layout(), &target_machine_features); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. - pipeline.AddPass>( - /*is_layout_sensitive=*/true, - [](const Shape&, const Shape&) { return true; }, - /*enable_dot_strength_reduction=*/false); - pipeline.AddPass(/*is_layout_sensitive=*/true); + { + auto& pass = pipeline.AddPass>( + "after layout assignement"); + pass.AddPass>( + /*is_layout_sensitive=*/true, + [](const Shape&, const Shape&) { return true; }, + /*enable_dot_strength_reduction=*/false); + pass.AddPass(); + pass.AddPass(/*is_layout_sensitive=*/true); + } pipeline.AddPass(BF16, F32); // Outline ops in the entry computation into calls to subcomputations. const int max_parallelism = @@ -579,7 +584,7 @@ StatusOr> CpuCompiler::RunBackend( IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), - &target_machine_features, jit->external_constant_pool()); + &target_machine_features); for (auto embedded_computation : entry_computation->MakeEmbeddedComputationsList()) { @@ -766,8 +771,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, IrEmitter ir_emitter(*module, *assignment, &llvm_module, std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), - &target_machine_features, - /*external_constant_pool=*/nullptr); + &target_machine_features); HloComputation* computation = module->entry_computation(); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index cf43b74c699ca8cbbef11a0abbaf4d69476f5d77..1093559892ddb9c238fd9c1f7e3d419ec7022776 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -206,8 +206,8 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( tensorflow::gtl::MutableArraySlice buffers) { se::Stream* stream = run_options->stream(); ScopedShapedBuffer result_buffer( - /*on_host_shape=*/host_result_shape(), - /*on_device_shape=*/host_result_shape(), run_options->allocator(), + /*on_host_shape=*/result_shape(), + /*on_device_shape=*/result_shape(), run_options->allocator(), stream->parent()->device_ordinal()); // Move OwningDeviceMemory values which contain the array(s) of the result diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 97e10a89a209c057685709e7a5034052ff4376ed..750310c633286aa8f964c9ae5dcf847f2dc0557c 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -501,8 +501,8 @@ TEST_F(OpcodeFusionTest, UnaryMapOfExp) { HloInstruction* exp = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0)); - builder.AddInstruction(HloInstruction::CreateMap( - shape, {exp}, CreateAdderToOne(module.get()), /*static_operands=*/{})); + builder.AddInstruction( + HloInstruction::CreateMap(shape, {exp}, CreateAdderToOne(module.get()))); module->AddEntryComputation(builder.Build()); @@ -525,8 +525,8 @@ TEST_F(OpcodeFusionTest, BinaryMapOfExps) { HloInstruction* exp1 = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kExp, param1)); - builder.AddInstruction(HloInstruction::CreateMap( - shape, {exp0, exp1}, CreateMax(module.get()), /*static_operands=*/{})); + builder.AddInstruction( + HloInstruction::CreateMap(shape, {exp0, exp1}, CreateMax(module.get()))); module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index d97802ee45d6add3c466577d7624d9ca74e2f380..b877b295814a7e13569a1837ed3e1787f2fc3f56 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -160,9 +160,8 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, int32 size_32 = static_cast(size); CpuInfeedBuffer* queued_buffer = new CpuInfeedBuffer(size_32); - Status s = - TransferBufferToDevice(executor, /*size=*/size, - /*source=*/source, queued_buffer->device_memory()); + Status s = executor->SynchronousMemcpyH2D( + /*host_src=*/source, /*size=*/size, queued_buffer->device_memory()); if (!s.ok()) { queued_buffer->Done(s); diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index e8b205051e2828b8f1d3ecd2161ae9d53d3f1796..58228180ca55ede50c8579bbd73cfdfffc07e208 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -1380,7 +1380,7 @@ Status DotOpEmitter::Emit() { // the rhs and lhs indexes with the reduction dimensions removed. The terms // from the rhs index are the lower dimensions in the index so we add them // first. - llvm_ir::IrArray::Index target_index; + llvm_ir::IrArray::Index target_index(lhs_index.GetType()); for (int dimension = 0; dimension < lhs_index.size(); ++dimension) { if (dimension != lhs_reduction_dimension) { target_index.push_back(lhs_index[dimension]); @@ -1404,10 +1404,13 @@ Status DotOpEmitter::Emit() { Status DotOpEmitter::EmitScalarDot() { // A scalar dot is just a scalar multiply. llvm::Value* result; + // Use the same index_type for all tensor accesses in the same kernel. + llvm::Type* index_type = ir_builder_->getInt64Ty(); + llvm_ir::IrArray::Index element_index(index_type); llvm::Value* lhs_value = - lhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_); + lhs_array_.EmitReadArrayElement(/*index=*/element_index, ir_builder_); llvm::Value* rhs_value = - rhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_); + rhs_array_.EmitReadArrayElement(/*index=*/element_index, ir_builder_); if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) { #define REAL(x) ir_builder_->CreateExtractValue(x, {0}) #define IMAG(x) ir_builder_->CreateExtractValue(x, {1}) @@ -1425,7 +1428,8 @@ Status DotOpEmitter::EmitScalarDot() { } else { result = ir_builder_->CreateFMul(lhs_value, rhs_value); } - target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_); + target_array_.EmitWriteArrayElement(/*index=*/element_index, result, + ir_builder_); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc deleted file mode 100644 index c56286559158758ca6db5ae097729286bde346f0..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" - -#include -#include -#include - -#include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/gtl/flatset.h" - -namespace xla { -namespace cpu { -void ExternalConstantPool::Insert(string name, const LiteralSlice& literal, - int64 alignment) { - CHECK(!ShapeUtil::IsTuple(literal.shape())); - CHECK(alignment > 0 && IsPowerOfTwo(static_cast(alignment))); - CHECK(entries_.find(name) == entries_.end()); - - const int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape()); - void* raw_pointer = tensorflow::port::AlignedMalloc( - literal_size, std::max(alignment, sizeof(void*))); - CHECK(raw_pointer != nullptr) << "failed to allocate " << literal_size - << " bytes with alignment of " << alignment; - - std::memcpy(raw_pointer, literal.untyped_data(), literal_size); - entries_.emplace(std::move(name), static_cast(raw_pointer)); -} - -const uint8* ExternalConstantPool::Find(const string& name) { - auto it = entries_.find(name); - return it == entries_.end() ? nullptr : it->second.get(); -} -} // namespace cpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h deleted file mode 100644 index 0677f5f0b58005079890052a426e5f48c5d09ed1..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ - -#include - -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/platform/mem.h" - -namespace xla { -namespace cpu { -// An ExternalConstantPool maintains a set of constants kept external to -// generated LLVM IR. These constants are accessed from the IR via globals with -// extern linkage. This current incarnation of ExternalConstantPool only -// supports the JIT CPU backend; the AOT backend is not supported. -// -// Implementation-wise, this is a simple wrapper around a map of strings to byte -// buffers. This simply implementation works in a JIT scenario. This class -// will have to become smarter if we decide to support external constant pools -// on AOT compiles in the future. -class ExternalConstantPool { - public: - // Inserts a buffer with the contents of `literal` into the constant pool with - // the name `name`. It is an error to try to insert two constants with the - // same `name` into the same constant pool. The buffer for literal is aligned - // to `aligment` bytes, and `alignment` must be a power of 2. - // - // The constant pool copies out the contents of `literal` into a buffer it - // owns -- it does not keep pointers to `literal`, or to memory owned by - // `literal`. - void Insert(string name, const LiteralSlice& literal, int64 alignment); - - // Find the constant with name `name` in this constant pool. If there isn't - // such constant, return nullptr. - const uint8* Find(const string& name); - - private: - // We need to `AlignedFree` pointers allocated into `entries_` since we - // allocate them with `AlignedMalloc`. - struct FreeDeleter { - void operator()(void* ptr) { tensorflow::port::AlignedFree(ptr); } - }; - - tensorflow::gtl::FlatMap> - entries_; -}; -} // namespace cpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc deleted file mode 100644 index 9290a4e5dfc03ddb86e9d82f1f0f4f9a8ceebb88..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/core/platform/test.h" - -namespace xla { -namespace cpu { -namespace { -class ExternalConstantPoolTest : public ::testing::Test {}; - -template -T GetFromBuffer(const uint8* buffer, int64 index) { - T result; - std::memcpy(&result, buffer + index * sizeof(T), sizeof(T)); - return result; -} - -TEST(ExternalConstantPoolTest, Basic) { - ExternalConstantPool constant_pool; - EXPECT_EQ(constant_pool.Find("name-0"), nullptr); - const auto literal = Literal::CreateR2({{1, 2}, {3, 4}}); - constant_pool.Insert("name-0", *literal, 4); - const uint8* constant = constant_pool.Find("name-0"); - ASSERT_NE(constant, nullptr); - - EXPECT_EQ(GetFromBuffer(constant, 0), 1); - EXPECT_EQ(GetFromBuffer(constant, 1), 2); - EXPECT_EQ(GetFromBuffer(constant, 2), 3); - EXPECT_EQ(GetFromBuffer(constant, 3), 4); - - EXPECT_EQ(constant_pool.Find("name-1"), nullptr); -} - -TEST(ExternalConstantPoolTest, RowMinorLayout) { - ExternalConstantPool constant_pool; - EXPECT_EQ(constant_pool.Find("name-0"), nullptr); - const auto literal = Literal::CreateR2WithLayout( - {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1})); - constant_pool.Insert("name-0", *literal, 4); - const uint8* constant = constant_pool.Find("name-0"); - ASSERT_NE(constant, nullptr); - - EXPECT_EQ(GetFromBuffer(constant, 0), 1); - EXPECT_EQ(GetFromBuffer(constant, 1), 3); - EXPECT_EQ(GetFromBuffer(constant, 2), 2); - EXPECT_EQ(GetFromBuffer(constant, 3), 4); -} - -TEST(ExternalConstantPoolTest, Alignment) { - ExternalConstantPool constant_pool; - EXPECT_EQ(constant_pool.Find("name-0"), nullptr); - - for (int i = 0; i < 8; i++) { - int64 alignment = 1 << i; - string name = tensorflow::strings::StrCat("name-", i); - - const auto literal = Literal::CreateR2({{1, 2}, {3, 4}}); - constant_pool.Insert(name, *literal, alignment); - - const uint8* constant = constant_pool.Find(name); - ASSERT_NE(constant, nullptr); - EXPECT_EQ(reinterpret_cast(constant) % alignment, 0); - } -} - -} // namespace -} // namespace cpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 758b8c62b4800215caae82208454ac971807f6eb..6b66a4b0b7cef0058a761801815606b9440016cf 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -48,6 +48,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" @@ -83,8 +85,7 @@ IrEmitter::IrEmitter( llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, - const TargetMachineFeatures* target_machine_features, - ExternalConstantPool* external_constant_pool) + const TargetMachineFeatures* target_machine_features) : assignment_(assignment), module_(llvm_module), arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), @@ -94,8 +95,7 @@ IrEmitter::IrEmitter( alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), hlo_module_config_(hlo_module.config()), is_top_level_computation_(false), - target_machine_features_(*target_machine_features), - external_constant_pool_(external_constant_pool) { + target_machine_features_(*target_machine_features) { ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() .xla_enable_fast_math())); @@ -161,45 +161,18 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { } llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { - llvm::Constant* result; - - // We avoid creating large constants in the LLVM IR since LLVM is not - // efficient for large constant arrays. We still emit "small enough" constant - // arrays into the Ir, in the off chance the LLVM optimizer can do something - // interesting with it. - // - // TODO(b/29904935): Remove the large constant pool. - const int kMaxInternalConstantSizeInBytes = 128; - if (external_constant_pool_ && - ByteSizeOf(literal.shape()) >= kMaxInternalConstantSizeInBytes) { - string global_name = tensorflow::strings::StrCat( - "constant_global_", external_global_constant_counter_++); - llvm::GlobalVariable* result_global = new llvm::GlobalVariable( - /*Module=*/*module_, - /*Type=*/IrShapeType(literal.shape()), - /*isConstant=*/true, - /*Linkage=*/llvm::GlobalValue::ExternalLinkage, - /*Initializer=*/nullptr, - /*Name=*/AsStringRef(global_name)); - result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); - external_constant_pool_->Insert(global_name, literal, - MinimumAlignmentForShape(literal.shape())); - result = result_global; - } else { - llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, module_); - llvm::GlobalVariable* result_global = new llvm::GlobalVariable( - /*Module=*/*module_, - /*Type=*/initializer->getType(), - /*isConstant=*/true, - /*Linkage=*/llvm::GlobalValue::PrivateLinkage, - /*Initializer=*/initializer, - /*Name=*/""); - result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); - result = llvm::ConstantExpr::getBitCast( - result_global, IrShapeType(literal.shape())->getPointerTo()); - } - return result; + llvm::Constant* initializer = + llvm_ir::ConvertLiteralToIrConstant(literal, module_); + llvm::GlobalVariable* result_global = new llvm::GlobalVariable( + /*Module=*/*module_, + /*Type=*/initializer->getType(), + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/initializer, + /*Name=*/""); + result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); + return llvm::ConstantExpr::getBitCast( + result_global, IrShapeType(literal.shape())->getPointerTo()); } Status IrEmitter::HandleConstant(HloInstruction* constant) { @@ -321,30 +294,42 @@ Status IrEmitter::HandleSelect(HloInstruction* select) { return DefaultAction(select); } -Status IrEmitter::HandleInfeed(HloInstruction* infeed) { +Status IrEmitter::HandleInfeed(HloInstruction* instruction) { + HloInfeedInstruction* infeed = Cast(instruction); VLOG(2) << "HandleInfeed: " << infeed->ToString(); - const Shape& shape = infeed->shape(); - - // The infeed operation produces data (dequeued from the infeed queue) at this - // address, which has been provided by buffer assignment. + // The infeed operation produces a two-element tuple containing data and a + // token value. HloInfeedInstruction::infeed_shape gives us the data shape. + const Shape& data_shape = infeed->infeed_shape(); + DCHECK(ShapeUtil::Equal(data_shape, + ShapeUtil::GetTupleElementShape(infeed->shape(), 0))); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(infeed)); - llvm_ir::IrArray infeed_array = GetIrArrayFor(infeed); - if (ShapeUtil::IsTuple(shape)) { - TF_RET_CHECK(!ShapeUtil::IsNestedTuple(shape)); + // Write the tuple index table. + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice, + assignment_.GetUniqueSlice(infeed, {0})); + llvm::Value* data_address = EmitTempBufferPointer(data_slice, data_shape); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice, + assignment_.GetUniqueSlice(infeed, {1})); + llvm::Value* token_address = EmitTempBufferPointer( + token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1)); + llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, + &ir_builder_, module_); + + if (ShapeUtil::IsTuple(data_shape)) { + TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape)); // For a tuple, we first copy each of the internal elements to // their corresponding target locations. We then construct the // tuple outer buffer containing pointers to the internal // elements. std::vector tuple_element_addresses; - for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) { + for (int64 i = 0; i < data_shape.tuple_shapes_size(); ++i) { TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer, - assignment_.GetUniqueSlice(infeed, {i})); + assignment_.GetUniqueSlice(infeed, {0, i})); const Shape& tuple_element_shape = - ShapeUtil::GetTupleElementShape(shape, i); + ShapeUtil::GetTupleElementShape(data_shape, i); // Only the outer tuple buffer's target address is obtained from // GetEmittedValueFor, to handle the case when Infeed is the root @@ -359,11 +344,11 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) { tuple_element_addresses.push_back(tuple_element_address); } - llvm_ir::EmitTuple(infeed_array, tuple_element_addresses, &ir_builder_, - module_); + llvm_ir::EmitTuple(llvm_ir::IrArray(data_address, data_shape), + tuple_element_addresses, &ir_builder_, module_); } else { - TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kInfeed, shape, - GetEmittedValueFor(infeed))); + TF_RETURN_IF_ERROR( + EmitXfeedTransfer(XfeedKind::kInfeed, data_shape, data_address)); } return Status::OK(); @@ -563,7 +548,8 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); - llvm_ir::IrArray::Index input_index(index.size()); + llvm_ir::IrArray::Index input_index(ir_builder_.getInt64Ty(), + index.size()); llvm::Value* in_bounds_condition = nullptr; for (size_t i = 0; i < index.size(); ++i) { llvm::Value* strided_index = ir_builder_.CreateNSWMul( @@ -694,7 +680,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // Compute the operand index to visit and evaluate the condition whether the // operand index is within the bounds. The unsigned comparison includes // checking whether the operand index >= 0. - llvm_ir::IrArray::Index operand_index(source_index.size()); + llvm_ir::IrArray::Index operand_index(ir_builder_.getInt64Ty(), + source_index.size()); llvm::Value* in_bounds_condition = ir_builder_.getTrue(); for (int64 i = 0; i < rank; ++i) { llvm::Value* strided_index = ir_builder_.CreateNSWMul( @@ -768,7 +755,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // value and the current output value. SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), &ir_builder_); - llvm_ir::IrArray::Index selected_index; + llvm_ir::IrArray::Index selected_index(source_index.GetType()); for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP( selected_index_address, {ir_builder_.getInt32(i)}); @@ -1110,7 +1097,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { // We are not in the padding, so carry out the computation. int num_dims = num_spatial_dims + 2; - llvm_ir::IrArray::Index input_index(num_dims); + llvm_ir::IrArray::Index input_index(ir_builder_.getInt64Ty(), num_dims); for (int i = 0; i < num_spatial_dims; ++i) { input_index[dnums.input_spatial_dimensions(i)] = input_spatial[i]; } @@ -1118,7 +1105,8 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { input_index[dnums.input_batch_dimension()] = batch; llvm_ir::IrArray kernel_array(GetIrArrayFor(rhs)); - llvm_ir::IrArray::Index kernel_index(num_dims); + llvm_ir::IrArray::Index kernel_index(ir_builder_.getInt64Ty(), + num_dims); for (int i = 0; i < num_spatial_dims; ++i) { kernel_index[dnums.kernel_spatial_dimensions(i)] = window.dimensions(i).window_reversal() @@ -1429,6 +1417,10 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, llvm::Value* rhs) { return ir_builder->CreateOr(lhs, rhs); }; + case HloOpcode::kXor: + return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, + llvm::Value* rhs) { return ir_builder->CreateXor(lhs, rhs); }; + case HloOpcode::kMaximum: return [root_is_floating_point, root_is_signed]( llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, @@ -1685,7 +1677,8 @@ StatusOr IrEmitter::EmitVectorizedReduce( // } llvm_ir::ForLoopNest loop_nest(IrName(reduce), &ir_builder_); - llvm_ir::IrArray::Index array_index(reduce->shape().dimensions_size()); + llvm_ir::IrArray::Index array_index(ir_builder_.getInt64Ty(), + reduce->shape().dimensions_size()); for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0; --i) { int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i); @@ -2069,7 +2062,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { // Compute the output index the operand element should be assigned to. // output_index := edge_padding_low + operand_index * (interior_padding + 1) const PaddingConfig& padding_config = pad->padding_config(); - llvm_ir::IrArray::Index output_index; + llvm_ir::IrArray::Index output_index(operand_index.GetType()); for (size_t i = 0; i < operand_index.size(); ++i) { llvm::Value* offset = ir_builder_.CreateMul( operand_index[i], @@ -2531,7 +2524,7 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { return Status::OK(); } -Status IrEmitter::HandleGenerateToken(HloInstruction* gen_token) { +Status IrEmitter::HandleAfterAll(HloInstruction* gen_token) { TF_RET_CHECK(ByteSizeOf(gen_token->shape()) == 0); // No code to generate, but we need to emit an address for book-keeping. TF_RETURN_IF_ERROR(EmitTargetAddressForOp(gen_token)); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index e1815c1db7a14dfc90ff646c0fd1e439ffffb2e8..3c110a320fad931e68e48236d4b4a33d0601ab5a 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -30,7 +30,6 @@ limitations under the License. #include "llvm/IR/Value.h" #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" -#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" #include "tensorflow/compiler/xla/service/cpu/ir_function.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -67,17 +66,13 @@ class IrEmitter : public DfsHloVisitorWithDefault { // index in the profiling array. // computation_to_profile_idx: the mapping from HLO computations to their // index in the profiling array. - // external_constant_pool: if non-null, points to an ExternalConstantPool - // instance into which the Ir emitter can spill - // constants. IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment, llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, - const TargetMachineFeatures* target_machine, - ExternalConstantPool* external_constant_pool); + const TargetMachineFeatures* target_machine); ~IrEmitter() override; // Emit and return the given HLO computation as an LLVM IR @@ -150,7 +145,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleWhile(HloInstruction* xla_while) override; Status HandleConcatenate(HloInstruction* concatenate) override; Status HandleConditional(HloInstruction* conditional) override; - Status HandleGenerateToken(HloInstruction* gen_token) override; + Status HandleAfterAll(HloInstruction* gen_token) override; Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; @@ -537,9 +532,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { const TargetMachineFeatures& target_machine_features_; - int64 external_global_constant_counter_ = 0; - ExternalConstantPool* external_constant_pool_; - struct LiteralPtrHashFunctor { size_t operator()(const Literal* literal) const { return literal->Hash(); } }; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc index 54af40506dab48b3c2a3a44eb0b5f5fb213a32ec..59ae5acd8b7cea049f09eaf4cc98b41339973c77 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -31,13 +31,15 @@ ParallelLoopEmitter::ParallelLoopEmitter( std::vector ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name) { + tensorflow::StringPiece loop_name, llvm::Type* index_type) { + CHECK_NE(index_type, nullptr); + CHECK(!ShapeUtil::IsTuple(shape_)); CHECK(!ShapeUtil::IsScalar(shape_)); llvm_ir::ForLoopNest loop_nest(loop_name, ir_builder_); const int64 num_dims = shape_.dimensions_size(); - llvm_ir::IrArray::Index array_index(num_dims); + llvm_ir::IrArray::Index array_index(index_type, num_dims); // Add loops from outer-most to inner-most dimensions. for (int i = LayoutUtil::MinorToMajor(shape_).size() - 1; i >= 0; --i) { diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h index 755715634aa70a822b21d25dcae20a8fe053477a..25e182a26d6f21c7eba550020cf17403aa92abf7 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h @@ -61,7 +61,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ~ParallelLoopEmitter() override = default; std::vector EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name) override; + tensorflow::StringPiece loop_name, llvm::Type* index_type) override; private: const DynamicLoopBounds* dynamic_loop_bounds_; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index fc2efbaf9a22b02cd729da2f367d53bc15506836..36c9f743859ae2da6c4fb3fd753bd7862fe2d3ab 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -110,8 +110,9 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { const string hlo_string = R"( HloModule TestTaskParallel_infeed_outfeed ENTRY InfeedOutfeed { - infeed0 = u32[12345678,2]{1,0} infeed() - ROOT outfeed0 = u32[12345678,2]{1,0} outfeed(infeed0) + infeed0 = (u32[12345678,2]{1,0}, token[]) infeed() + infeed0.data = u32[12345678,2]{1,0} get-tuple-element((u32[12345678,2]{1,0}, token[]) infeed0), index=0 + ROOT outfeed0 = token[] outfeed(infeed0.data) } )"; diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index 167aa4adda995a259190a932a76a34ca5883444c..7e792a82b8bf28121c054332bc619d736858c729 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -49,9 +49,9 @@ int main(int argc, char** argv) { // Build computation. xla::XlaBuilder builder(""); - auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto add = builder.Add(p1, p0, {0}); + auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Add(p1, p0, {0}); xla::StatusOr computation_status = builder.Build(); xla::XlaComputation computation = computation_status.ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index c4c90515ac7ec2721cb9ea48d42e3c5080e249af..be772cfb7e564cebc5725854dbf5678e5c507556 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -127,13 +127,6 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, } llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { - if (const uint8* from_constant_pool = - external_constant_pool_.Find(string(name))) { - return llvm::JITEvaluatedSymbol( - reinterpret_cast(from_constant_pool), - llvm::JITSymbolFlags::None); - } - void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name); if (func_addr == nullptr) { return nullptr; diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index 1851a3ee0bb97b4860605d7211a6ae70ac88686b..d74b63fcf45bd70cd18ee41f1e9714ba6a222abd 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -29,7 +29,6 @@ limitations under the License. #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/disassembler.h" -#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -91,10 +90,6 @@ class SimpleOrcJIT { llvm::TargetMachine* target_machine() const { return target_machine_.get(); } - ExternalConstantPool* external_constant_pool() { - return &external_constant_pool_; - } - // Creates an llvm::TargetMachine suitable for JITting code that will run on // the current machine. static std::unique_ptr InferTargetMachineForJIT( @@ -112,7 +107,6 @@ class SimpleOrcJIT { std::shared_ptr symbol_resolver_; ObjLayerT object_layer_; CompileLayerT compile_layer_; - ExternalConstantPool external_constant_pool_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc index faac927027c48e44eb8ff1fcc4109fbc177fc579..1d4bf483aedef5a15ef51cf216030b76255d4ec8 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -56,7 +56,8 @@ class CpuExternalConstantsTest : public CpuCodegenTest { TEST_F(CpuExternalConstantsTest, Basic) { TestWithArray(/*rows=*/1024, /*cols=*/1024, R"( -CHECK: @constant_global_0 = external constant [1024 x [1024 x float]], align 16 +CHECK-NOT: @constant_global_0 = external constant [1024 x [1024 x float]], align 16 +CHECK: @0 = private constant [4194304 x i8] {{.*}}, align 16 )"); } @@ -65,7 +66,7 @@ TEST_F(CpuExternalConstantsTest, BasicNegative) { // to externalize it. TestWithArray(/*rows=*/4, /*cols=*/4, R"( CHECK-NOT: @constant_global_0 = external constant [16 x float], align 8 -CHECK: @0 = private constant [16 x float] {{.*}}, align 8 +CHECK: @0 = private constant [64 x i8] {{.*}}, align 8 )"); } } // namespace diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index 23e7a3de4d8188a3add259582e11030539e154c1..783b2820e922612973632c555fc8ae01418f1754 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -96,8 +96,11 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { HloInstruction::CreateUnary(vshape, HloOpcode::kExp, ceil)); auto floor = builder.AddInstruction( HloInstruction::CreateUnary(vshape, HloOpcode::kFloor, exp)); - auto two = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + auto two = builder.AddInstruction(HloInstruction::CreateBroadcast( + vshape, + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.0))), + {})); builder.AddInstruction( HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, two, floor)); @@ -114,9 +117,9 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { EXPECT_EQ(HloOpcode::kFusion, fusion_instruction->opcode()); EXPECT_EQ(HloOpcode::kMultiply, fusion_instruction->fused_expression_root()->opcode()); - // There should be 7 fused instructions: 2 parameters and the fused + // There should be 8 fused instructions: 2 parameters and the fused // operations. - EXPECT_EQ(7, fusion_instruction->fused_instruction_count()); + EXPECT_EQ(8, fusion_instruction->fused_instruction_count()); // Compile and execute the computation. auto result = ExecuteAndTransfer(std::move(module), {}); @@ -170,8 +173,11 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { HloInstruction::CreateUnary(cshape, HloOpcode::kExp, reduce)); auto floor = builder.AddInstruction( HloInstruction::CreateUnary(cshape, HloOpcode::kFloor, exp)); - auto two = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + auto two = builder.AddInstruction(HloInstruction::CreateBroadcast( + cshape, + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.0))), + {})); builder.AddInstruction( HloInstruction::CreateBinary(cshape, HloOpcode::kMultiply, two, floor)); @@ -188,9 +194,9 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { EXPECT_EQ(HloOpcode::kFusion, fusion_instruction1->opcode()); EXPECT_EQ(HloOpcode::kMultiply, fusion_instruction1->fused_expression_root()->opcode()); - // There should be 5 fused instructions in the root fusion instruction: 2 + // There should be 6 fused instructions in the root fusion instruction: 2 // parameters, multiply, floor, and exp. - EXPECT_EQ(5, fusion_instruction1->fused_instruction_count()) + EXPECT_EQ(6, fusion_instruction1->fused_instruction_count()) << fusion_instruction1->fused_instructions_computation()->ToString(); auto fusion_instruction2 = reduce->operand(0); diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc index dd63b998e9b6d04981ec6f7300c883c9b23b154f..ea7e479d66fbda1bfd388fd77b25db2db56f0d65 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc @@ -47,7 +47,7 @@ class InfeedTest : public ClientLibraryTestBase { // don't use ResetDevice since it is not implemented on CPU. ASSERT_IS_OK(client_->TransferToInfeed(literal)); XlaBuilder builder(TestName()); - builder.Infeed(literal.shape()); + Infeed(&builder, literal.shape()); if (ShapeUtil::IsTuple(literal.shape())) { // TODO(b/30609564): Use ComputeAndCompareLiteral instead. ComputeAndCompareTuple(&builder, literal, {}); @@ -125,8 +125,8 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.Gt(builder.ConstantR0(40.0f), prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + Gt(ConstantR0(&builder, 40.0f), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: add the reduced value of the Infeed @@ -134,17 +134,16 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto infeed = builder.Infeed(infeed_shape); - auto addend = - builder.Reduce(infeed, builder.ConstantR0(0.0f), - CreateScalarAddComputation(F32, &builder), {0}); - builder.Add(prev, addend); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto infeed = Infeed(&builder, infeed_shape); + auto addend = Reduce(infeed, ConstantR0(&builder, 0.0f), + CreateScalarAddComputation(F32, &builder), {0}); + Add(prev, addend); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - auto init = builder.ConstantR0(0.0f); - builder.While(condition, body, init); + auto init = ConstantR0(&builder, 0.0f); + While(condition, body, init); // Build and asynchronously launch the computation. auto computation = builder.Build().ConsumeValueOrDie(); @@ -207,8 +206,8 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.GetTupleElement(prev, 1); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + GetTupleElement(prev, 1); condition = builder.Build().ConsumeValueOrDie(); } @@ -221,27 +220,27 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { const auto build_body = [this, &result_shape](const Shape& infeed_shape) { XlaComputation body; XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto infeed = builder.Infeed(infeed_shape); - auto addend = builder.Reduce( - builder.GetTupleElement(infeed, 0), builder.ConstantR0(0.0f), - CreateScalarAddComputation(F32, &builder), {0}); - auto result = builder.Add(builder.GetTupleElement(prev, 0), addend); - builder.Tuple({result, builder.GetTupleElement(infeed, 1)}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto infeed = Infeed(&builder, infeed_shape); + auto addend = + Reduce(GetTupleElement(infeed, 0), ConstantR0(&builder, 0.0f), + CreateScalarAddComputation(F32, &builder), {0}); + auto result = Add(GetTupleElement(prev, 0), addend); + Tuple(&builder, {result, GetTupleElement(infeed, 1)}); return builder.Build().ConsumeValueOrDie(); }; // Create the first while loop with infeed1_shape. - auto init = builder.Tuple( - {builder.ConstantR0(0.0f), builder.ConstantR0(true)}); - auto while1 = builder.While(condition, build_body(infeed1_shape), init); - auto result1 = builder.Tuple( - {builder.GetTupleElement(while1, 0), builder.ConstantR0(true)}); + auto init = Tuple(&builder, {ConstantR0(&builder, 0.0f), + ConstantR0(&builder, true)}); + auto while1 = While(condition, build_body(infeed1_shape), init); + auto result1 = Tuple( + &builder, {GetTupleElement(while1, 0), ConstantR0(&builder, true)}); // Create the second while loop with infeed2_shape. Note that the result from // the first while loop is used as the initial value. - auto while2 = builder.While(condition, build_body(infeed2_shape), result1); - builder.GetTupleElement(while2, 0); + auto while2 = While(condition, build_body(infeed2_shape), result1); + GetTupleElement(while2, 0); // Build the computation. auto computation = builder.Build().ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index 27044b1d62027e3b83744c486cb790269e505aff..90b99c828e2fcfd77579026a39d3a6711599feee 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -38,7 +38,8 @@ while_body { while_cond { arg_cond = f32[2,3,2] parameter(0) - ROOT unknown = pred[] infeed() + infeed = (pred[], token[]) infeed() + ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } ENTRY main { @@ -49,14 +50,14 @@ ENTRY main { {{2, 1}, {2001, 3002}, {2001, 2002}}}) const_b = f32[2,3,2] while(f32[2,3,2] const_a), condition=while_cond, body=while_body - out0 = () outfeed(f32[2,3,2] const_a) - ROOT out1 = () outfeed(f32[2,3,2] const_b) + out0 = token[] outfeed(f32[2,3,2] const_a) + ROOT out1 = token[] outfeed(f32[2,3,2] const_b) } )"; string filecheck_pattern = R"( -CHECK: private constant [12 x float] -CHECK-NOT: private constant [12 x float] +CHECK: private constant [48 x i8] +CHECK-NOT: private constant [48 x i8] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -84,7 +85,8 @@ while_body { while_cond { arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) - ROOT unknown = pred[] infeed() + infeed = (pred[], token[]) infeed() + ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } ENTRY main { @@ -98,10 +100,10 @@ ENTRY main { )"; string filecheck_pattern = R"( -CHECK: private constant [1 x float] -CHECK: private constant [2 x float] -CHECK-NOT: private constant [1 x float] -CHECK-NOT: private constant [2 x float] +CHECK: private constant [4 x i8] +CHECK: private constant [8 x i8] +CHECK-NOT: private constant [4 x i8] +CHECK-NOT: private constant [8 x i8] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc index 1ee279290b6fcfe775ce9867d424b1c031f5d2bd..dac416e1c78c2f60d458480c5062f48b77d4878d 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -32,12 +32,13 @@ ENTRY main { {{{1, 2}, {1001, 1002}, {2001, 2002}}, {{2, 1}, {2001, 3002}, {2001, 2002}}}) - ROOT out = () outfeed(f32[2,3,2] const_a) + outfeed = token[] outfeed(f32[2,3,2] const_a) + ROOT root = () tuple() } )"; string filecheck_pattern = R"( -CHECK: private constant [12 x float] +CHECK: private constant [48 x i8] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index ee2b455730f8f520db6652f0352f8a96291cac73..cb3676c5ba9b55ef4cb46dbd97f84ea9a6a6c5d0 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -183,6 +183,9 @@ class DfsHloVisitorBase { virtual Status HandleOr(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } + virtual Status HandleXor(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); + } virtual Status HandleShiftLeft(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } @@ -243,7 +246,7 @@ class DfsHloVisitorBase { virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0; - virtual Status HandleGenerateToken(HloInstructionPtr token) = 0; + virtual Status HandleAfterAll(HloInstructionPtr token) = 0; // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 6934e00a4b665e9e6a4302e0c0a8ce1d5bb94373..987c91e5ba3eb01a7535d162cbcf6441d568adae 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -188,7 +188,7 @@ class DfsHloVisitorWithDefaultBase Status HandleGather(HloInstructionPtr gather) override { return DefaultAction(gather); } - Status HandleGenerateToken(HloInstructionPtr token) override { + Status HandleAfterAll(HloInstructionPtr token) override { return DefaultAction(token); } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 93fea7ead7a86bb34c449668fd88a58145681eb1..ce0951bbe1873973c7b97055aba5ba71a14ad24f 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -1164,6 +1164,8 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( return ir_builder_->CreateAnd(lhs_value, rhs_value); case HloOpcode::kOr: return ir_builder_->CreateOr(lhs_value, rhs_value); + case HloOpcode::kXor: + return ir_builder_->CreateXor(lhs_value, rhs_value); // Shifting out bits >= the number of bits in the type being shifted // produces a poison value in LLVM which is basically "deferred undefined @@ -1220,7 +1222,7 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( const Shape& operand_shape = hlo.operand(operand_no)->shape(); // If the operand is scalar, the source index is always {}. if (ShapeUtil::IsScalar(operand_shape)) { - return llvm_ir::IrArray::Index(); + return llvm_ir::IrArray::Index(target_index.GetType()); } // If no implicit broadcast is needed for this operand, returns the target @@ -1232,13 +1234,13 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( // If implicit broadcast is needed, the source dimensions that are broadcast // have index 0. CHECK_EQ(ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(hlo.shape())); - llvm_ir::IrArray::Index source_index; + llvm_ir::IrArray::Index source_index(target_index.GetType()); for (int64 i = 0; i < ShapeUtil::Rank(hlo.shape()); ++i) { if (hlo.shape().dimensions(i) == operand_shape.dimensions(i)) { source_index.push_back(target_index[i]); } else { CHECK_EQ(1, operand_shape.dimensions(i)); - source_index.push_back(ir_builder_->getInt64(0)); + source_index.push_back(target_index.GetConstantWithIndexType(0)); } } return source_index; @@ -1540,9 +1542,14 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( // Emit IR to read dynamic start indices from hlo->operand(1). const HloInstruction* input_hlo = hlo->operand(0); const int64 rank = ShapeUtil::Rank(input_hlo->shape()); - llvm_ir::IrArray::Index slice_start_index(rank); + // Use the same index type for all tensor accesses in the same kernel. + llvm::Type* index_type = index.GetType(); + llvm_ir::IrArray::Index slice_start_index(index_type, rank); for (int64 i = 0; i < rank; ++i) { - llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_type, c); + }; + llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, operand_to_generator.at(hlo->operand(1))(dim_index)); @@ -1552,17 +1559,17 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( // TODO(b/74360564): This is implementation defined behavior, but is // currently respected by all implementations. Change this if we ever decide // to oficially document different behavior. - start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value, - index[i]->getType()); - llvm::Value* operand_dim_size = llvm::ConstantInt::get( - start_index_value->getType(), input_hlo->shape().dimensions(i)); - llvm::Value* output_dim_size = llvm::ConstantInt::get( - start_index_value->getType(), hlo->shape().dimensions(i)); + start_index_value = + ir_builder_->CreateSExtOrTrunc(start_index_value, index_type); + llvm::Value* operand_dim_size = + index_typed_const(input_hlo->shape().dimensions(i)); + llvm::Value* output_dim_size = + index_typed_const(hlo->shape().dimensions(i)); start_index_value = EmitIntegralMin( ir_builder_->CreateSub(operand_dim_size, output_dim_size), - EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0), - start_index_value, /*is_signed=*/true), + EmitIntegralMax(index_typed_const(0), start_index_value, + /*is_signed=*/true), /*is_signed=*/true); start_index_value->setName( @@ -1570,7 +1577,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( slice_start_index[i] = start_index_value; } - llvm_ir::IrArray::Index input_index(rank); + llvm_ir::IrArray::Index input_index(index_type, rank); for (int64 i = 0; i < rank; ++i) { // Emit IR which computes: // input_index = start_index + offset_index @@ -1594,17 +1601,18 @@ StatusOr ElementalIrEmitter::EmitElementalGather( const llvm_ir::ElementGenerator& indices_generator = operand_to_generator.at(hlo->operand(1)); + llvm::Type* index_type = index.GetType(); // This is the index into `operand` that holds the element we want to // generate. This index "unsafe" as in the components in here may be // out of bounds. - IrArray::Index unsafe_operand_index; + IrArray::Index unsafe_operand_index(index_type); // First copy in the window indices to unsafe_operand_index. for (int64 i = 0, e = operand_shape.dimensions_size(), unsafe_operand_index_dim = 0; i < e; i++) { if (c_binary_search(dim_numbers.elided_window_dims(), i)) { - unsafe_operand_index.push_back(ir_builder_->getInt64(0)); + unsafe_operand_index.push_back(index.GetConstantWithIndexType(0)); } else { unsafe_operand_index.push_back( index[dim_numbers.output_window_dims(unsafe_operand_index_dim++)]); @@ -1612,7 +1620,7 @@ StatusOr ElementalIrEmitter::EmitElementalGather( } // This is the index of the index vector in the gather_indices tensor. - IrArray::Index gather_index_index; + IrArray::Index gather_index_index(index_type); { std::vector gather_index_index_components; for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) { @@ -1628,8 +1636,8 @@ StatusOr ElementalIrEmitter::EmitElementalGather( auto add_to_unsafe_operand_index = [&](llvm::Value* index_component, int64 dim) { - llvm::Value* gather_dim_component_extended = ir_builder_->CreateSExtOrTrunc( - index_component, ir_builder_->getInt64Ty()); + llvm::Value* gather_dim_component_extended = + ir_builder_->CreateSExtOrTrunc(index_component, index_type); unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)] = ir_builder_->CreateAdd( unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)], @@ -1645,18 +1653,18 @@ StatusOr ElementalIrEmitter::EmitElementalGather( indices_shape.dimensions(dim_numbers.index_vector_dim()); for (int64 i = 0; i < index_vector_size; i++) { gather_index_index[dim_numbers.index_vector_dim()] = - ir_builder_->getInt64(i); + index.GetConstantWithIndexType(i); TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, indices_generator(gather_index_index)); add_to_unsafe_operand_index(gather_dim_component, i); } } - IrArray::Index safe_operand_index; + IrArray::Index safe_operand_index(index_type); for (int64 i = 0, e = unsafe_operand_index.size(); i < e; i++) { safe_operand_index.push_back(ir_builder_->CreateURem( unsafe_operand_index[i], - ir_builder_->getInt64(operand_shape.dimensions(i)))); + index.GetConstantWithIndexType(operand_shape.dimensions(i)))); } return operand_generator(safe_operand_index); @@ -1671,14 +1679,18 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( const HloInstruction* start_hlo = hlo->operand(2); // Calculate slice start/end indices. const int64 rank = ShapeUtil::Rank(input_hlo->shape()); - llvm_ir::IrArray::Index slice_start_index(rank); - llvm_ir::IrArray::Index slice_limit_index(rank); + llvm_ir::IrArray::Index slice_start_index(index.GetType(), rank); + llvm_ir::IrArray::Index slice_limit_index(index.GetType(), rank); // Slice intersection gathers (ANDs) conditions on all ranks for which // 'input' is set to 'update' llvm::Value* slice_intersection = ir_builder_->getTrue(); for (int64 i = 0; i < rank; ++i) { - llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); + llvm::Type* index_type = index[0]->getType(); + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_type, c); + }; + llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, operand_to_generator.at(start_hlo)(dim_index)); @@ -1688,18 +1700,18 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // TODO(b/74360564): This is implementation defined behavior, but is // currently respected by all implementations. Change this if we ever decide // to oficially document different behavior. - start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value, - index[i]->getType()); - llvm::Value* input_dim_size = llvm::ConstantInt::get( - index[i]->getType(), input_hlo->shape().dimensions(i)); - llvm::Value* update_dim_size = llvm::ConstantInt::get( - index[i]->getType(), update_hlo->shape().dimensions(i)); - - start_index_value = EmitIntegralMin( - ir_builder_->CreateSub(input_dim_size, update_dim_size), - EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0), - start_index_value, /*is_signed=*/true), - /*is_signed=*/true); + start_index_value = + ir_builder_->CreateSExtOrTrunc(start_index_value, index_type); + llvm::Value* input_dim_size = + index_typed_const(input_hlo->shape().dimensions(i)); + llvm::Value* update_dim_size = + index_typed_const(update_hlo->shape().dimensions(i)); + + start_index_value = + EmitIntegralMin(ir_builder_->CreateSub(input_dim_size, update_dim_size), + EmitIntegralMax(index_typed_const(0), start_index_value, + /*is_signed=*/true), + /*is_signed=*/true); start_index_value->setName( AsStringRef(IrName(hlo, StrCat("start_idx", i)))); @@ -1729,7 +1741,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // Handle true BB (return data from 'update') SetToFirstInsertPoint(if_data.true_block, ir_builder_); // Compute update index for intersection case. - llvm_ir::IrArray::Index update_index(rank); + llvm_ir::IrArray::Index update_index(index.GetType(), rank); for (int64 i = 0; i < rank; ++i) { update_index[i] = ir_builder_->CreateSub(index[i], slice_start_index[i]); } @@ -1797,7 +1809,8 @@ StatusOr ElementalIrEmitter::EmitElementalPad( SetToFirstInsertPoint(if_data.false_block, ir_builder_); TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, - operand_to_generator.at(hlo->operand(1))({})); + operand_to_generator.at(hlo->operand(1))( + IrArray::Index(index.GetType()))); ir_builder_->CreateStore(padding_value, ret_value_addr); SetToFirstInsertPoint(if_data.after_block, ir_builder_); @@ -1824,10 +1837,15 @@ StatusOr ElementalIrEmitter::EmitElementalDot( int64 lhs_dims = hlo->operand(0)->shape().dimensions_size(); int64 rhs_dims = hlo->operand(1)->shape().dimensions_size(); - std::unique_ptr inner_loop = llvm_ir::ForLoop::EmitForLoop( - IrName(hlo, "inner"), ir_builder_->getInt64(0), - ir_builder_->getInt64(contracted_dim_size), ir_builder_->getInt64(1), - ir_builder_); + llvm::Type* index_type = dot_result_index[0]->getType(); + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_type, c); + }; + + std::unique_ptr inner_loop = + llvm_ir::ForLoop::EmitForLoop(IrName(hlo, "inner"), index_typed_const(0), + index_typed_const(contracted_dim_size), + index_typed_const(1), ir_builder_); SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), ir_builder_); PrimitiveType primitive_type = hlo->shape().element_type(); @@ -1846,7 +1864,7 @@ StatusOr ElementalIrEmitter::EmitElementalDot( // Given an output index [a,b,c,d,e] in the result, we compute: // sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T)) - IrArray::Index lhs_index, rhs_index; + IrArray::Index lhs_index(index_type), rhs_index(index_type); for (int64 i = 0; i < lhs_dims - 1; i++) { lhs_index.push_back(dot_result_index[i]); @@ -1945,6 +1963,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kMultiply: case HloOpcode::kNe: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kShiftLeft: diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 6df172db8e541c5cef7aab04f3d8611fc735e8b0..fd75847d0c0e737957401b8efc420d504a3c0706 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -82,7 +82,18 @@ StatusOr Executable::ExecuteOnStreamWrapper( StatusOr return_value = ExecuteOnStream(run_options, arguments, profile_ptr.get()); - TF_RETURN_IF_ERROR(return_value.status()); + if (!return_value.status().ok()) { + if (profile != nullptr) { + // Ensure the ThenStartTimer call has completed before we destroy timer. + // We already have a failure status to return, so just log this if it + // fails. + Status status = stream->BlockHostUntilDone(); + if (!status.ok()) { + LOG(ERROR) << "Failed to BlockHostUntilDone: " << status; + } + } + return return_value.status(); + } if (profile != nullptr) { VLOG(1) << "enqueueing 'stop timer' and blocking host until done..."; @@ -116,6 +127,11 @@ StatusOr Executable::ExecuteOnStreamWrapper( if (profile->compute_time_ns() == 0) { profile->set_compute_time_ns(profile->compute_and_transfer_time_ns()); } + + const int64 executable_size_in_bytes = SizeInBytes(); + if (executable_size_in_bytes != 0) { + profile->set_executable_size_in_bytes(executable_size_in_bytes); + } } if (profile_ptr != nullptr) { @@ -129,6 +145,8 @@ StatusOr Executable::ExecuteOnStreamWrapper( return return_value; } +int64 Executable::SizeInBytes() { return -1; } + Status Executable::DumpHloSnapshot() { TF_RET_CHECK(dumping_snapshot()); TF_RET_CHECK(hlo_snapshot_->has_hlo() && diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index dc1f26ea65cc707d4f0522af2aa3ec40621632f1..98eaeee30a693211ae564a5ef3c373f0364bef11 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -88,8 +88,7 @@ class Executable { // called explicitly for other (async, for example) variants after the stream // has completed. virtual Status PopulateExecutionProfile( - HloExecutionProfile* hlo_execution_profile, - se::StreamExecutor* executor) { + HloExecutionProfile* hlo_execution_profile, se::Stream* stream) { return Status::OK(); } @@ -132,10 +131,14 @@ class Executable { // The shape (including layout) that results from this execution. This is the // shape of the DeviceMemoryBase result value in ExecuteOnStream above. - const Shape& host_result_shape() const { - return hlo_module_->config().host_entry_computation_layout().result_shape(); + const Shape& result_shape() const { + return hlo_module_->config().entry_computation_layout().result_shape(); } + // Returns the size of the executable in bytes. Returns -1 by default if the + // method is not overridden to support this kind of query. + virtual int64 SizeInBytes(); + // Dumping helpers. void set_hlo_snapshot(std::unique_ptr hlo_snapshot) { hlo_snapshot_ = std::move(hlo_snapshot); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index d9f62c21c4ef932bb61f2f9e0f7a318366ce94f0..85e28a0dfe38415974e435106a2d0b75863f2df5 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -43,7 +43,7 @@ se::Platform::Id GenericTransferManager::PlatformId() const { } Status GenericTransferManager::WriteSingleTupleIndexTable( - se::StreamExecutor* executor, + se::Stream* stream, tensorflow::gtl::ArraySlice elements, const Shape& shape, se::DeviceMemoryBase* region) { TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape)); @@ -52,12 +52,24 @@ Status GenericTransferManager::WriteSingleTupleIndexTable( for (const se::DeviceMemoryBase& element : elements) { element_pointers.push_back(element.opaque()); } - return TransferBufferToDevice(executor, GetByteSizeRequirement(shape), - element_pointers.data(), region); + TF_RETURN_IF_ERROR(TransferBufferToDevice( + stream, GetByteSizeRequirement(shape), element_pointers.data(), region)); + // Ensure the buffer is transferred before we destroy element_pointers. + return stream->BlockHostUntilDone(); +} + +void GenericTransferManager::TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + std::function>)> done) { + Status status = stream->BlockHostUntilDone(); + if (!status.ok()) { + return done(status); + } + done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer)); } StatusOr> -GenericTransferManager::TransferLiteralFromDevice( +GenericTransferManager::TransferLiteralFromDeviceInternal( se::StreamExecutor* executor, const ShapedBuffer& device_buffer) { VLOG(2) << "transferring literal from device ordinal " << executor->device_ordinal() << "; device buffer: " << device_buffer; @@ -75,8 +87,7 @@ GenericTransferManager::TransferLiteralFromDevice( device_buffer.on_host_shape(), [&](const Shape& subshape, const ShapeIndex& index) -> Status { if (ShapeUtil::IsArray(subshape)) { - TF_RETURN_IF_ERROR(TransferBufferFromDevice( - executor, + TF_RETURN_IF_ERROR(executor->SynchronousMemcpyD2H( /*source=*/device_buffer.buffer(index), /*size=*/GetByteSizeRequirement(subshape), /*destination=*/ @@ -88,8 +99,8 @@ GenericTransferManager::TransferLiteralFromDevice( return std::move(literal); } -Status GenericTransferManager::TransferLiteralToDevice( - se::StreamExecutor* executor, const LiteralSlice& literal, +Status GenericTransferManager::TransferLiteralToDeviceAsync( + se::Stream* stream, const LiteralSlice& literal, const ShapedBuffer& device_buffer) { const Shape& shape = literal.shape(); VLOG(2) << "transferring literal shape to device: " @@ -103,9 +114,10 @@ Status GenericTransferManager::TransferLiteralToDevice( TF_RET_CHECK( ShapeUtil::Compatible(literal.shape(), device_buffer.on_host_shape())); - TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); + TF_RET_CHECK(stream->parent()->device_ordinal() == + device_buffer.device_ordinal()); - TF_RETURN_IF_ERROR(WriteTupleIndexTables(executor, device_buffer)); + TF_RETURN_IF_ERROR(WriteTupleIndexTables(stream, device_buffer)); return ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_host_shape(), @@ -121,16 +133,21 @@ Status GenericTransferManager::TransferLiteralToDevice( if (LayoutUtil::Equal(device_subshape.layout(), subliteral.shape().layout())) { source = subliteral.untyped_data(); + return TransferBufferToDevice( + stream, + /*size=*/GetByteSizeRequirement(device_subshape), source, + &device_memory); } else { // Relayout data before transferring. relayed_out_literal = subliteral.Relayout(device_subshape.layout(), /*shape_index=*/{}); source = relayed_out_literal->untyped_data(); + TF_RETURN_IF_ERROR(TransferBufferToDevice( + stream, + /*size=*/GetByteSizeRequirement(device_subshape), source, + &device_memory)); + return stream->BlockHostUntilDone(); } - return TransferBufferToDevice( - executor, - /*size=*/GetByteSizeRequirement(device_subshape), source, - &device_memory); } return Status::OK(); }); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 3da9570ef7eebcdf618439f628fb4d5589993e4f..d216fe7d29e8f2e84ab4f558ee5caec32d07a70a 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -41,12 +41,13 @@ class GenericTransferManager : public TransferManager { se::Platform::Id PlatformId() const override; - StatusOr> TransferLiteralFromDevice( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer) override; + void TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + std::function>)> done) override; - Status TransferLiteralToDevice(se::StreamExecutor* executor, - const LiteralSlice& literal, - const ShapedBuffer& device_buffer) override; + Status TransferLiteralToDeviceAsync( + se::Stream* stream, const LiteralSlice& literal, + const ShapedBuffer& device_buffer) override; Status TransferLiteralToInfeed(se::StreamExecutor* executor, const LiteralSlice& literal) override; @@ -64,11 +65,14 @@ class GenericTransferManager : public TransferManager { const void* source) override; Status WriteSingleTupleIndexTable( - se::StreamExecutor* executor, + se::Stream* stream, tensorflow::gtl::ArraySlice elements, const Shape& shape, se::DeviceMemoryBase* region) override; private: + StatusOr> TransferLiteralFromDeviceInternal( + se::StreamExecutor* executor, const ShapedBuffer& device_buffer); + // The platform this transfer manager targets. const se::Platform::Id platform_id_; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 541a5275a384ebfd900be086216c6d0c6958cd88..88f994786a50b2516df845602af796eb12baf579 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -442,6 +442,7 @@ cc_library( srcs = ["multi_output_fusion.cc"], hdrs = ["multi_output_fusion.h"], deps = [ + ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:multi_output_fusion", @@ -583,7 +584,6 @@ cc_library( "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", - "//tensorflow/compiler/xla/service:gather_expander", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", @@ -613,7 +613,6 @@ cc_library( "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", "@llvm//:core", - "@llvm//:support", ], alwayslink = True, # Contains compiler registration ) @@ -771,6 +770,7 @@ cc_library( hdrs = ["stream_executor_util.h"], deps = [ "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:stream_executor_no_cuda", ], diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc index db6924c742e4a949a3e939b6d6659e92c2d1e312..c77e3c81c9d38af7857ad1389d20221514bf38f1 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc @@ -126,12 +126,17 @@ Status Visitor::HandleBatchNormTraining(HloInstruction* batch_norm) { HloInstruction* variance_plus_epsilon = computation_->AddInstruction(HloInstruction::CreateBinary( inverse_stddev->shape(), HloOpcode::kPower, inverse_stddev, - computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-2))))); + computation_->AddInstruction(HloInstruction::CreateBroadcast( + inverse_stddev->shape(), + computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(-2))), + {})))); HloInstruction* variance = computation_->AddInstruction(HloInstruction::CreateBinary( variance_plus_epsilon->shape(), HloOpcode::kSubtract, - variance_plus_epsilon, epsilon)); + variance_plus_epsilon, + computation_->AddInstruction(HloInstruction::CreateBroadcast( + variance_plus_epsilon->shape(), epsilon, {})))); // Repackage the results. std::unique_ptr new_tuple = HloInstruction::CreateTuple({ @@ -175,12 +180,17 @@ Status Visitor::HandleBatchNormGrad(HloInstruction* batch_norm) { HloInstruction* var_plus_epsilon = computation_->AddInstruction(HloInstruction::CreateBinary( batch_norm->operand(3)->shape(), HloOpcode::kAdd, - batch_norm->mutable_operand(3), epsilon)); + batch_norm->mutable_operand(3), + computation_->AddInstruction(HloInstruction::CreateBroadcast( + batch_norm->operand(3)->shape(), epsilon, {})))); HloInstruction* inverse_stddev = computation_->AddInstruction(HloInstruction::CreateBinary( var_plus_epsilon->shape(), HloOpcode::kPower, var_plus_epsilon, - computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-.5))))); + computation_->AddInstruction(HloInstruction::CreateBroadcast( + var_plus_epsilon->shape(), + computation_->AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0(-.5))), + {})))); std::vector operands(batch_norm->operands().begin(), batch_norm->operands().end()); diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index b812dd7d3fbb25f279e87f79b647e299f29073ea..27d2c3e491bfc2108cbd168d1a5e1575c2eed11f 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -376,11 +376,17 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( "reduce_window_accum_ptr", ir_builder_); { TF_ASSIGN_OR_RETURN(llvm::Value * init_value, - operand_to_generator.at(hlo->operand(1))({})); + operand_to_generator.at(hlo->operand(1))( + IrArray::Index(index.GetType()))); ir_builder_->CreateStore(init_value, accum_ptr); } - llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_); + llvm::Type* index_type = index.GetType(); + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return index.GetConstantWithIndexType(c); + }; + + llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_, index_type); std::vector window_size; for (const auto& dim : window.dimensions()) { window_size.push_back(dim.size()); @@ -391,14 +397,14 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), ir_builder_); - IrArray::Index input_index(index.size()); + IrArray::Index input_index(index_type, index.size()); llvm::Value* in_bounds = ir_builder_->getInt1(true); for (size_t i = 0; i < index.size(); ++i) { llvm::Value* stridden_index = ir_builder_->CreateNSWMul( - index[i], ir_builder_->getInt64(window.dimensions(i).stride())); + index[i], index_typed_const(window.dimensions(i).stride())); input_index[i] = ir_builder_->CreateNSWSub( ir_builder_->CreateNSWAdd(stridden_index, window_index[i]), - ir_builder_->getInt64(window.dimensions(i).padding_low())); + index_typed_const(window.dimensions(i).padding_low())); // We must check whether 0 ≤ input_index[i] < bound, as otherwise // we are in the pad and so can skip the computation. This @@ -409,7 +415,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( in_bounds, ir_builder_->CreateICmpULT( input_index[i], - ir_builder_->getInt64(operand->shape().dimensions(i)))); + index_typed_const(operand->shape().dimensions(i)))); } llvm_ir::LlvmIfData if_data = @@ -435,11 +441,13 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( llvm::Value* accum_ptr = ir_builder()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType( hlo->shape().element_type(), module_)); + llvm::Type* index_type = output_index.GetType(); TF_ASSIGN_OR_RETURN(llvm::Value * init_value, - operand_to_generator.at(hlo->operand(1))({})); + operand_to_generator.at(hlo->operand(1))( + IrArray::Index(index_type))); ir_builder()->CreateStore(init_value, accum_ptr); - llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_); + llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_, index_type); IrArray::Index input_index = loops.AddLoopsForShapeOnDimensions( operand->shape(), hlo->dimensions(), "reduction_dim"); if (!ShapeUtil::IsScalar(hlo->shape())) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 9d66648a402fb82c35e0bf3ea1179f7995ed7c76..decfc40dafafe875fa02bab6695f5c54e522f267 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" -#include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" @@ -165,9 +164,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - // Rewrite gather ops into smaller ones. - pass.AddPass(); - // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. pipeline.AddPass(); @@ -209,7 +205,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { HloPassPipeline pipeline("layout_assignment"); pipeline.AddPass( - hlo_module->mutable_device_entry_computation_layout(), stream_exec); + hlo_module->mutable_entry_computation_layout(), stream_exec); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index c5ccdd4a7dcec02ddab8a1f748659de41f6202d2..fbc1303085b579e898d2f503a341754109768567 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -52,60 +52,20 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { HloDataflowAnalysis::Run(*module)); // Make sure all operands of a library call are in memory instead of constants - // in IR. - for (HloInstruction* hlo : - module->entry_computation()->MakeInstructionPostOrder()) { - // Inserts a copy of hlo->operand(n) if it's a constant. - auto copy_operand_if_constant = [&](int64 n) -> Status { - HloInstruction* operand = hlo->mutable_operand(n); - TF_RET_CHECK(ShapeUtil::IsArray(operand->shape())); - const auto& values = dataflow->GetValueSet(operand).values(); - if (std::any_of(values.begin(), values.end(), [](const HloValue* value) { - return value->defining_instruction()->opcode() == - HloOpcode::kConstant; - })) { - TF_ASSIGN_OR_RETURN(HloInstruction * copy, FindOrInsertCopy(operand)); - TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(n, copy)); - changed = true; - } - return Status::OK(); - }; - - if (IsCustomCallToDnnBatchNorm(*hlo)) { - // The epsilon and feature_index operands to a CUDNN batchnorm op don't - // need to be materialized in memory -- in fact, they must be constants. - // These are the last two operands of all three batchnorm ops. - for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { - TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); - } - } else if (ImplementedAsLibraryCall(*hlo) || - hlo->opcode() == HloOpcode::kCrossReplicaSum) { - // For all other library calls and cross-replica-sum, materialize all the - // operands into memory. (Cross-replica-sum gets its constant args - // materialized even if it's not implemented as a libcall to simplify the - // implementation. It's slower, but we can constant fold away constant - // args *anyway*, so we just need to make it work.) - for (int64 i = 0; i < hlo->operand_count(); ++i) { - TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); - } - } - } - - // Init values of while and conditional nodes cannot be constants. Insert - // copies for any constants found at the operands of these nodes. + // in IR. Also, init values of while and conditional nodes cannot be + // constants. Insert copies for any constants found at the operands of these + // nodes. tensorflow::gtl::FlatSet inserted_copies; for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() != HloOpcode::kWhile && - instruction->opcode() != HloOpcode::kConditional) { - continue; - } - for (auto operand : instruction->operands()) { + for (HloInstruction* hlo : computation->instructions()) { + // Inserts a copy of hlo->operand(n) if it's a constant. + auto copy_operand_if_constant = [&](int64 n) -> Status { + HloInstruction* operand = hlo->mutable_operand(n); // Skip the operands that have already been replaced with a copy in a // previous iteration (which is possible when a constant is used as an // operand in multiple places). if (ContainsKey(inserted_copies, operand)) { - continue; + return Status::OK(); } for (auto& pair : dataflow->GetInstructionValueSet(operand)) { const HloValueSet& value_set = pair.second; @@ -121,6 +81,47 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { } } } + return Status::OK(); + }; + + if (IsCustomCallToDnnBatchNorm(*hlo)) { + // The epsilon and feature_index operands to a CUDNN batchnorm op don't + // need to be materialized in memory -- in fact, they must be constants. + // These are the last two operands of all three batchnorm ops. + for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { + TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); + } + } else if (ImplementedAsLibraryCall(*hlo) || + hlo->opcode() == HloOpcode::kCrossReplicaSum || + hlo->opcode() == HloOpcode::kWhile || + hlo->opcode() == HloOpcode::kConditional) { + // For all other library calls, cross-replica-sum, while and conditional + // ops materialize all the operands into memory. (Cross-replica-sum + // gets its constant args materialized even if it's not implemented as a + // libcall to simplify the implementation. It's slower, but we can + // constant fold away constant args *anyway*, so we just need to make it + // work.) + for (int64 i = 0; i < hlo->operand_count(); ++i) { + TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); + } + } + } + } + + if (changed) { + // Check the assumption that the epsilon and feature_index constants of the + // CUDNN batchnorm op are not shared with other ops where we would replace + // them with a copy. These custom op calls are generated with the + // CudnnBatchNormRewriter, so this would only happen if HloCSE merges them. + for (HloComputation* computation : module->computations()) { + for (HloInstruction* hlo : computation->instructions()) { + if (!IsCustomCallToDnnBatchNorm(*hlo)) { + continue; + } + for (int64 i = hlo->operand_count() - 2; i < hlo->operand_count(); + ++i) { + CHECK_EQ(hlo->operand(i)->opcode(), HloOpcode::kConstant); + } } } } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index 7bb8df6581b49b1bf8c84a972f715e8dc119d8de..5343497c03c13a2589363da0fa33e18520220826 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -55,33 +55,28 @@ Status GpuTransferManager::TransferLiteralToInfeed( return TransferBufferToInfeed(executor, size, literal.untyped_data()); } - if (ShapeUtil::IsNestedTuple(shape)) { - return Unimplemented( - "Infeed with a nested tuple shape is not supported: %s", - ShapeUtil::HumanString(literal.shape()).c_str()); - } - // For a tuple, we transfer each of its elements to the device and // enqueue the resulting destination device addresses with the // infeed manager. std::vector buffers; - buffers.reserve(ShapeUtil::TupleElementCount(shape)); auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() { for (gpu::InfeedBuffer* b : buffers) { b->Done(); } }); - for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - const Shape& tuple_element_shape = - ShapeUtil::GetTupleElementShape(shape, i); - int64 tuple_element_size = GetByteSizeRequirement(tuple_element_shape); - TF_ASSIGN_OR_RETURN( - gpu::InfeedBuffer * buffer, - TransferBufferToInfeedInternal(executor, tuple_element_size, - literal.untyped_data({i}))); - buffers.push_back(buffer); - } + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + shape, [&](const Shape& literal_subshape, const ShapeIndex& index) { + if (ShapeUtil::IsArray(literal_subshape)) { + int64 tuple_element_size = GetByteSizeRequirement(literal_subshape); + TF_ASSIGN_OR_RETURN( + gpu::InfeedBuffer * buffer, + TransferBufferToInfeedInternal(executor, tuple_element_size, + literal.untyped_data(index))); + buffers.push_back(buffer); + } + return Status::OK(); + })); cleanup.release(); return EnqueueBuffersToInfeed(executor, buffers); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index e303999c63ff699487bc2362850459ab691f6bc8..d420863b8569771b16a03591b6a0ddd0591f7e2e 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -137,7 +137,7 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte, } llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, - const ShapeIndex& shape_index, + ShapeIndexView shape_index, llvm::Value* ir_value) { llvm::Type* pointee_type = llvm_ir::ShapeToIrType( ShapeUtil::GetSubshape(hlo.shape(), shape_index), module_); @@ -158,7 +158,7 @@ llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value, - const ShapeIndex& shape_index) { + ShapeIndexView shape_index) { VLOG(2) << "Binding " << hlo.ToString(); const Shape& hlo_shape = hlo.shape(); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index 3d34311b4368d17cb074aaf33c71fc865e96387e..a86e6e78c693ac53bb2c70d88b999a4e1273ecad 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -51,7 +51,7 @@ class HloToIrBindings { // Rebinds the given HLO to the LLVM IR value that represent its address. void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value, - const ShapeIndex& shape_index = {}); + ShapeIndexView shape_index = {}); // Unbinds all IR values that's defined in an LLVM function, e.g., function // arguments and stack variables. Global variables will be kept in bindings_. @@ -71,7 +71,7 @@ class HloToIrBindings { // A helper method that returns the base pointer of the IrArray containing the // output of "inst".at the given ShapeIndex. llvm::Value* GetBasePointer(const HloInstruction& hlo, - const ShapeIndex& shape_index = {}) const { + ShapeIndexView shape_index = {}) const { auto it = base_ptrs_.find(&hlo); CHECK(it != base_ptrs_.end()) << hlo.ToString(); return it->second.element(shape_index); @@ -97,7 +97,7 @@ class HloToIrBindings { // Returns an llvm typed ir representation of 'ir_value' based on 'hlo' shape. llvm::Value* GetTypedIrValue(const HloInstruction& hlo, - const ShapeIndex& shape_index, + ShapeIndexView shape_index, llvm::Value* ir_value); const BufferAssignment* buffer_assignment_; diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index ea34d5b30c91e8b809e3e17a904e27e589fd6b5f..2b63d8727cb11f4369b17adb87bcba18ed2b8b65 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -22,29 +22,29 @@ namespace xla { namespace gpu { InfeedThunk::InfeedThunk( - tensorflow::gtl::ArraySlice tuple_element_buffers, - const BufferAllocation::Slice& destination_buffer, + const ShapeTree& infeed_slices, const HloInstruction* hlo_instruction) - : Thunk(Kind::kInfeed, hlo_instruction), - tuple_element_buffers_(tuple_element_buffers.begin(), - tuple_element_buffers.end()), - destination_buffer_(destination_buffer) {} + : Thunk(Kind::kInfeed, hlo_instruction), infeed_slices_(infeed_slices) {} Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, se::Stream* stream) { VLOG(2) << "Infeeding to GPU "; - se::DeviceMemoryBase destination_address = - buffer_allocations.GetDeviceAddress(destination_buffer_); - + // First copy the infeed data which is element 0 of the infeed instruction's + // two-tuple output (the other element is a token). + se::DeviceMemoryBase data_address = + buffer_allocations.GetDeviceAddress(infeed_slices_.element({0})); InfeedManager* infeed_manager = GetOrCreateInfeedManager(); std::vector infeed_buffers; - if (ShapeUtil::IsTuple(hlo_instruction()->shape())) { - CHECK(!ShapeUtil::IsNestedTuple(hlo_instruction()->shape())); + const Shape& data_shape = + ShapeUtil::GetTupleElementShape(hlo_instruction()->shape(), 0); + if (ShapeUtil::IsTuple(data_shape)) { + CHECK(!ShapeUtil::IsNestedTuple(data_shape)); // Transfer the tuple elements first. std::vector tuple_element_addresses; - for (BufferAllocation::Slice tuple_element_buffer : - tuple_element_buffers_) { + for (int i = 0; i < ShapeUtil::TupleElementCount(data_shape); ++i) { + const BufferAllocation::Slice& tuple_element_buffer = + infeed_slices_.element({0, i}); se::DeviceMemoryBase tuple_element_address = buffer_allocations.GetDeviceAddress(tuple_element_buffer); @@ -56,15 +56,23 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, } // Transfer the tuple outer buffer. auto host_size = tuple_element_addresses.size() * sizeof(void*); - stream->ThenMemcpy(&destination_address, tuple_element_addresses.data(), + stream->ThenMemcpy(&data_address, tuple_element_addresses.data(), host_size); } else { InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer(); infeed_buffers.push_back(buffer); - stream->ThenMemcpy(&destination_address, *(buffer->device_memory()), + stream->ThenMemcpy(&data_address, *(buffer->device_memory()), buffer->length()); } + // Construct top-level tuple of infeed containing the data and the token. Use + // a nullptr for the token, it should never be dereferenced. + std::vector infeed_addresses = {data_address.opaque(), nullptr}; + se::DeviceMemoryBase top_level_address = + buffer_allocations.GetDeviceAddress(infeed_slices_.element({})); + stream->ThenMemcpy(&top_level_address, infeed_addresses.data(), + 2 * sizeof(void*)); + Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to complete data transfer on stream %p: %s", diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h index 93713cb12defd95bdd69cb0aa7ad7b4e37fc8fae..cb9a6232f3bcdcbf37bc195069bac449a7217401 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h @@ -32,12 +32,8 @@ namespace gpu { class InfeedThunk : public Thunk { public: // Constructs a InfeedThunk that copies data from the on-device - // infeed queue to the device buffer - // `destination_buffer`. `mem_size` is the size of the data in - // bytes. - InfeedThunk(tensorflow::gtl::ArraySlice - tuple_element_buffers, - const BufferAllocation::Slice& destination_buffer, + // infeed queue into the buffers in the given shape tree. + InfeedThunk(const ShapeTree& infeed_slices, const HloInstruction* hlo_instruction); InfeedThunk(const InfeedThunk&) = delete; @@ -47,8 +43,7 @@ class InfeedThunk : public Thunk { se::Stream* stream) override; private: - const std::vector tuple_element_buffers_; - const BufferAllocation::Slice destination_buffer_; + const ShapeTree infeed_slices_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 6c4519185b34989eb53c884ba214d69b824b113c..64ed3d748febd8281a8e602194b31c937a4a682a 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -40,6 +40,7 @@ bool IsFusile(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kDynamicSlice || hlo.opcode() == HloOpcode::kDynamicUpdateSlice || hlo.opcode() == HloOpcode::kFusion || + hlo.opcode() == HloOpcode::kGather || hlo.opcode() == HloOpcode::kPad || hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kReduceWindow || diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 7b7dd673a5c35e586105f1a6253c72c3aa0b0151..d5e07c3afb7dcb7e7a848b8c02e413c21d8ea155 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -94,10 +94,7 @@ Status IrEmitter::HandleConstant(HloInstruction* constant) { << std::endl << " its type: " << llvm_ir::DumpToString(*global_for_const->getType()); - llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast( - global_for_const, - llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); - bindings_.BindHloToIrValue(*constant, shape_constant); + bindings_.BindHloToIrValue(*constant, global_for_const); return Status::OK(); } @@ -194,6 +191,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( HloOpcode root_opcode = computation.root_instruction()->opcode(); PrimitiveType element_type = computation.root_instruction()->shape().element_type(); + bool is_atomic_integral = element_type == S32 || element_type == U32 || + element_type == S64 || element_type == U64; llvm::Value* source = ir_builder_.CreateLoad(source_address, "source"); if (root_opcode == HloOpcode::kAdd) { // NVPTX supports atomicAdd on F32 and integer types. @@ -204,7 +203,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( {output_address->getType()}, &ir_builder_); return true; } - if (primitive_util::IsIntegralType(element_type)) { + if (is_atomic_integral) { // integral + integral ir_builder_.CreateAtomicRMW(llvm::AtomicRMWInst::Add, output_address, source, @@ -213,9 +212,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( } } - // NVPTX supports atomicMax and atomicMin on only integer types. - if (root_opcode == HloOpcode::kMaximum && - primitive_util::IsIntegralType(element_type)) { + // NVPTX supports atomicMax and atomicMin only on integer types. + if (root_opcode == HloOpcode::kMaximum && is_atomic_integral) { // max(integral, integral) auto opcode = primitive_util::IsSignedIntegralType(element_type) ? llvm::AtomicRMWInst::Max @@ -225,8 +223,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( return true; } - if (root_opcode == HloOpcode::kMinimum && - primitive_util::IsIntegralType(element_type)) { + if (root_opcode == HloOpcode::kMinimum && is_atomic_integral) { // min(integral, integral) auto opcode = primitive_util::IsSignedIntegralType(element_type) ? llvm::AtomicRMWInst::Min @@ -478,12 +475,15 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { const Shape& lhs_shape = lhs_instruction->shape(); const Shape& rhs_shape = rhs_instruction->shape(); + // TODO(b/110211620): Convert to use i32 index_type when it is possible. + llvm::Type* index_type = ir_builder_.getInt64Ty(); + llvm_ir::IrArray::Index element_index(index_type); if (ShapeUtil::IsScalar(lhs_shape) && ShapeUtil::IsScalar(rhs_shape)) { // If the operands are scalar, don't emit any loops. llvm::Value* lhs_value = - lhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_); + lhs_array.EmitReadArrayElement(/*index=*/element_index, &ir_builder_); llvm::Value* rhs_value = - rhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_); + rhs_array.EmitReadArrayElement(/*index=*/element_index, &ir_builder_); llvm::Value* result; if (ShapeUtil::ElementIsComplex(lhs_shape)) { auto value = MultiplyComplex(lhs_value, rhs_value, &ir_builder_); @@ -493,7 +493,8 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { } else { result = ir_builder_.CreateFMul(lhs_value, rhs_value); } - target_array.EmitWriteArrayElement(/*index=*/{}, result, &ir_builder_); + target_array.EmitWriteArrayElement(/*index=*/element_index, result, + &ir_builder_); return Status::OK(); } @@ -584,7 +585,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // address. The index into the target address is the concatenation of the rhs // and lhs indexes with the reduction dimensions removed. The terms from the // rhs index are the lower dimensions in the index so we add them first. - llvm_ir::IrArray::Index target_index; + llvm_ir::IrArray::Index target_index(index_type); for (size_t dimension = 0; dimension < lhs_index.size(); ++dimension) { if (dimension != lhs_reduction_dimension) { target_index.push_back(lhs_index[dimension]); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index bb47a4280541ce2806472aa9365bb0ef38c0c3b3..c9574c87a3be208915b3d6a32679553eb425d2f0 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -120,9 +120,10 @@ Status IrEmitterNested::EmitTargetElementLoop( // For MOF we give the loop emitter an array for every output it should // generate. if (hlo.IsMultiOutputFusion()) { + const int64 num_elems = ShapeUtil::TupleElementCount(hlo.shape()); std::vector target_arrays; - for (int64 i = 0, e = ShapeUtil::TupleElementCount(hlo.shape()); i != e; - ++i) { + target_arrays.reserve(num_elems); + for (int64 i = 0; i != num_elems; ++i) { target_arrays.push_back(GetIrArray(hlo, hlo, {i})); } TF_RETURN_IF_ERROR( @@ -130,6 +131,7 @@ Status IrEmitterNested::EmitTargetElementLoop( .EmitLoop()); std::vector tuple_operand_ptrs; + tuple_operand_ptrs.reserve(num_elems); for (const llvm_ir::IrArray& array : target_arrays) { tuple_operand_ptrs.push_back(array.GetBasePointer()); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index ccbd99a0420ae8d5183fa112468b3f7cc678503e..bdb9e77da4d4fda23cad128fc6400a1205e7d54b 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -283,6 +283,69 @@ int ComputeMaxUnrollFactor(const HloInstruction* hlo) { // Cannot unroll. return 1; } + +// Returns the llvm type for the indices used in the kernel that contains the +// hlo instruction. Such indices include the index for the parallel loop and +// the indices for the tensors accessed by the kernel. The return type is i32 +// iff the following conditions are met: +// . The launch_size of the kernel is within the range of i32. +// . The sizes of all the tensors accessed within the kernel are within the +// range of i32. +// Otherwise, the return type is i64. +llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, + llvm::IRBuilder<>* ir_builder) { + // Find the unnested hlo instructon for which the kernel is generated for. + const HloInstruction* unnested_hlo = hlo; + const HloComputation* computation = hlo->parent(); + if (computation->IsFusionComputation()) { + unnested_hlo = computation->FusionInstruction(); + } + + auto shape_in_range = [&](const Shape& s) { + bool in_range = true; + ShapeUtil::ForEachSubshape( + s, [&](const Shape& sub_shape, const ShapeIndex& /*index*/) { + if (ShapeUtil::IsArray(sub_shape) && + !IsInt32(ShapeUtil::ElementsIn(sub_shape))) { + in_range = false; + } + }); + + return in_range; + }; + + llvm::Type* i64_ty = ir_builder->getInt64Ty(); + // Check launch dimension + if (!IsInt32(launch_size)) { + return i64_ty; + } + + // Check the size of result tensors + if (!shape_in_range(unnested_hlo->shape())) { + return i64_ty; + } + + auto hlo_shape_in_range = [&](const HloInstruction* operand) -> bool { + return shape_in_range(operand->shape()); + }; + + // Check the size of input tensors + if (!c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) { + return i64_ty; + } + + // Check the size of the internal result tensors + if (unnested_hlo->opcode() == HloOpcode::kFusion) { + if (!c_all_of( + unnested_hlo->fused_instructions_computation()->instructions(), + hlo_shape_in_range)) { + return i64_ty; + } + } + + return ir_builder->getInt32Ty(); +} + } // namespace Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { @@ -551,17 +614,16 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { if (root->opcode() == HloOpcode::kTuple) { output_shape_index = {i}; } - // TODO(kramerb): CHECK that layouts are equal. Currently this - // breaks multioutputfusion_test. The test has pre-fused - // instructions, but layout_assignment will not assign any layouts - // for instructions inside of a fused computation. It just removes - // the layouts instead. if (inst->opcode() == HloOpcode::kReduce) { - CHECK(ShapeUtil::Compatible(first_reduce->shape(), inst->shape())); - CHECK(ShapeUtil::Compatible(first_reduce->operand(0)->shape(), - inst->operand(0)->shape())); - CHECK(ShapeUtil::Compatible(first_reduce->operand(1)->shape(), - inst->operand(1)->shape())); + CHECK(IsReductionToVector(*inst)) + << "Only reductions to vector are supported"; + // Shapes, layouts and dimensions must be the same for all reduces + // inside of this fusion. + CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape())); + CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(), + inst->operand(0)->shape())); + CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(), + inst->operand(1)->shape())); CHECK(first_reduce->dimensions() == inst->dimensions()); input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); init_value_gens.push_back( @@ -569,8 +631,13 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { reducers.push_back(inst->to_apply()); reduce_output_shapes.push_back(std::move(output_shape_index)); } else { - CHECK(ShapeUtil::Compatible(first_reduce->operand(0)->shape(), - inst->shape())); + // For extra outputs we can relax shape equality to allow different + // types (with the same number of elements). Layouts still have to + // match. + CHECK(ShapeUtil::CompatibleIgnoringElementType( + first_reduce->operand(0)->shape(), inst->shape())); + CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(), + inst->shape().layout())); extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst), std::move(output_shape_index)); } @@ -1002,6 +1069,20 @@ Status IrEmitterUnnested::EmitReductionToScalar( int64 num_tiles = RoundUpToNearest(CeilOfRatio(num_elems, kTileSize), kWarpSize); + Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( + reduce->shape().element_type(), {num_tiles}, {0}); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + tiled_input_shape, ir_emitter_context_->device_description()); + + llvm::Type* index_ty = GetIndexTypeForKernel( + reduce, + launch_dimensions.block_count() * launch_dimensions.threads_per_block(), + &ir_builder_); + + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + // Check whether every thread will process a full tile's worth of elements // without reading outside the bounds of the input. If this is true, we can // skip some bounds checks in the final algorithm. @@ -1050,40 +1131,42 @@ Status IrEmitterUnnested::EmitReductionToScalar( llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result." + llvm::Twine(i)); - TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, - init_value_gens[i](llvm_ir::IrArray::Index({}))); + TF_ASSIGN_OR_RETURN( + llvm::Value* const init_ir_value, + init_value_gens[i](llvm_ir::IrArray::Index(index_ty))); ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } llvm::Value* x_in_tiles = tile_index[0]; + x_in_tiles = ir_builder_.CreateZExtOrTrunc(x_in_tiles, index_ty); // Emit an inner for-loop that reduces the elements in the tile. auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { std::unique_ptr tile_element_loop = - llvm_ir::ForLoop::EmitForLoop("element_id_in_tile", - ir_builder_.getInt64(0), - ir_builder_.getInt64(kTileSize), - ir_builder_.getInt64(1), &ir_builder_); + llvm_ir::ForLoop::EmitForLoop( + "element_id_in_tile", index_typed_const(0), + index_typed_const(kTileSize), index_typed_const(1), &ir_builder_); // Emit the body of the partial reduction loop. llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), &ir_builder_); llvm::Value* x = ir_builder_.CreateNSWAdd( - ir_builder_.CreateNSWMul(x_in_tiles, ir_builder_.getInt64(kTileSize)), + ir_builder_.CreateNSWMul(x_in_tiles, index_typed_const(kTileSize)), tile_element_loop->GetIndVarValue()); // Unless we know the tile is entirely in bounds, we have to emit a // x-in-bounds check before reading from the input. if (!tile_in_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpULT(x, ir_builder_.getInt64(num_elems)), + ir_builder_.CreateICmpULT(x, index_typed_const(num_elems)), "x_in_bounds", &ir_builder_); // Emit code that reads the input element and accumulates it to // the partial reduction result. llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_); } + llvm_ir::IrArray::Index input_index( /*linear=*/x, input_shape, &ir_builder_); llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); @@ -1102,12 +1185,12 @@ Status IrEmitterUnnested::EmitReductionToScalar( // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's // immediately beyond the tile. llvm::Value* x_end = ir_builder_.CreateNSWAdd( - ir_builder_.getInt64(kTileSize), - ir_builder_.CreateNSWMul(x_in_tiles, ir_builder_.getInt64(kTileSize))); + index_typed_const(kTileSize), + ir_builder_.CreateNSWMul(x_in_tiles, index_typed_const(kTileSize))); // The tile is entirely in bound if all_threads_in_bounds or // x_end <= num_elems. llvm::Value* tile_in_bounds = ir_builder_.CreateOr( - ir_builder_.CreateICmpULE(x_end, ir_builder_.getInt64(num_elems)), + ir_builder_.CreateICmpULE(x_end, index_typed_const(num_elems)), ir_builder_.getInt1(all_threads_in_bounds)); llvm_ir::LlvmIfData if_tile_in_bounds_data = llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_); @@ -1158,9 +1241,9 @@ Status IrEmitterUnnested::EmitReductionToScalar( // lane 0 (which holds the partially accumulated result for its warp) to the // output element. llvm::Value* lane_id = ir_builder_.CreateURem( - x_in_tiles, ir_builder_.getInt64(kWarpSize), "lane_id"); + x_in_tiles, index_typed_const(kWarpSize), "lane_id"); llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpEQ(lane_id, ir_builder_.getInt64(0)), + ir_builder_.CreateICmpEQ(lane_id, index_typed_const(0)), "lane_id_is_zero", &ir_builder_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &ir_builder_); @@ -1182,10 +1265,6 @@ Status IrEmitterUnnested::EmitReductionToScalar( }; // Emit a parallel loop that iterates through all input tiles, one per thread. - Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), {num_tiles}, {0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); CHECK(LastThunk()->kind() == Thunk::Kind::kSequential); UpdateLaunchDimensions( launch_dimensions, @@ -1193,7 +1272,7 @@ Status IrEmitterUnnested::EmitReductionToScalar( ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, launch_dimensions, &ir_builder_) - .EmitLoop(IrName(reduce)); + .EmitLoop(IrName(reduce), index_ty); } Status IrEmitterUnnested::EmitColumnReduction( @@ -1224,6 +1303,17 @@ Status IrEmitterUnnested::EmitColumnReduction( // If the height is not a multiple of the tile size, we pad the bottom of the // input matrix. const int64 height_in_tiles = CeilOfRatio(height, kTileSize); + Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( + reduce->shape().element_type(), {height_in_tiles, width}, {1, 0}); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + tiled_input_shape, ir_emitter_context_->device_description()); + + // TODO(b/110211620): Convert to use i32 index_type when it is possible. + llvm::Type* index_ty = ir_builder_.getInt64Ty(); + + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; // linear_index < height_in_tiles * width; @@ -1259,8 +1349,9 @@ Status IrEmitterUnnested::EmitColumnReduction( llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result." + llvm::Twine(i)); - TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, - init_value_gens[i](llvm_ir::IrArray::Index({}))); + TF_ASSIGN_OR_RETURN( + llvm::Value* const init_ir_value, + init_value_gens[i](llvm_ir::IrArray::Index(index_ty))); ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); @@ -1271,24 +1362,27 @@ Status IrEmitterUnnested::EmitColumnReduction( llvm::Value* y_in_tiles = tile_index[0]; llvm::Value* x = tile_index[1]; + y_in_tiles = ir_builder_.CreateZExtOrTrunc(y_in_tiles, index_ty); + x = ir_builder_.CreateZExtOrTrunc(x, index_ty); + auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { std::unique_ptr tile_element_loop = - llvm_ir::ForLoop::EmitForLoop("element_id_in_tile", - ir_builder_.getInt64(0), - ir_builder_.getInt64(kTileSize), - ir_builder_.getInt64(1), &ir_builder_); + llvm_ir::ForLoop::EmitForLoop( + "element_id_in_tile", index_typed_const(0), + index_typed_const(kTileSize), index_typed_const(1), &ir_builder_); // Emit the body of the partial reduction loop. llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), &ir_builder_); llvm::Value* y = ir_builder_.CreateNSWAdd( - ir_builder_.CreateNSWMul(y_in_tiles, ir_builder_.getInt64(kTileSize)), + ir_builder_.CreateNSWMul(y_in_tiles, index_typed_const(kTileSize)), tile_element_loop->GetIndVarValue()); + // Unless we know the tile is entirely in bounds, we have to emit a // y-in-bounds check before reading from the input. if (!tile_in_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpULT(y, ir_builder_.getInt64(height)), + ir_builder_.CreateICmpULT(y, index_typed_const(height)), "y_in_bounds", &ir_builder_); // Emit code that reads the input element and accumulates it to @@ -1338,10 +1432,10 @@ Status IrEmitterUnnested::EmitColumnReduction( // y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's // immediately beyond the tile. llvm::Value* y_end = ir_builder_.CreateNSWAdd( - ir_builder_.getInt64(kTileSize), - ir_builder_.CreateNSWMul(y_in_tiles, ir_builder_.getInt64(kTileSize))); + index_typed_const(kTileSize), + ir_builder_.CreateNSWMul(y_in_tiles, index_typed_const(kTileSize))); llvm::Value* tile_in_bounds = ir_builder_.CreateOr( - ir_builder_.CreateICmpULE(y_end, ir_builder_.getInt64(height)), + ir_builder_.CreateICmpULE(y_end, index_typed_const(height)), ir_builder_.getInt1(height % kTileSize == 0)); // The tile is entirely in bound if "height" is a multiple of kTileSize or // y_end <= height. @@ -1378,10 +1472,6 @@ Status IrEmitterUnnested::EmitColumnReduction( }; // Emit a parallel loop that iterate through all input tiles. - Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), {height_in_tiles, width}, {1, 0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); CHECK(LastThunk()->kind() == Thunk::Kind::kSequential); UpdateLaunchDimensions( launch_dimensions, @@ -1389,7 +1479,7 @@ Status IrEmitterUnnested::EmitColumnReduction( ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, launch_dimensions, &ir_builder_) - .EmitLoop(IrName(reduce)); + .EmitLoop(IrName(reduce), index_ty); } static std::pair ComputeTilingSchemeForReduction( @@ -1441,7 +1531,7 @@ Status IrEmitterUnnested::EmitRowReduction( // for (element_id_in_tile : range(x_tile_size)) { // int x = x_in_tiles * x_tile_size + element_id_in_tile; // if (x < width) - // partial_result = reducer(partial_result, input[z][y][z]); + // partial_result = reducer(partial_result, input[z][y][x]); // } // AtomicReducer(&output[y], partial_result); // } @@ -1495,10 +1585,11 @@ Status IrEmitterUnnested::EmitRowReduction( // for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size; // ++element_id_in_z_tile) { // z = z_in_tiles * z_tile_size + element_id_in_z_tile; + // int tx = x; // for (int element_id_in_x_tile = 0; // element_id_in_x_tile < x_tile_size; - // ++element_id_in_x_tile, x += warpSize) { - // partial_result = Reducer(partial_result, input[z][y][x]); + // ++element_id_in_x_tile, tx += warpSize) { + // partial_result = Reducer(partial_result, input[z][y][tx]); // } // } // } else { @@ -1506,10 +1597,11 @@ Status IrEmitterUnnested::EmitRowReduction( // for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size; // ++element_id_in_z_tile) { // z = z_in_tiles * z_tile_size + element_id_in_z_tile; + // int tx = x; // for (int element_id_in_x_tile = 0; element_id_in_x_tile < - // x_tile_size; ++element_id_in_tile, x += warpSize) { - // if (x < width) - // partial_result = Reducer(partial_result, input[z][y][x]); + // x_tile_size; ++element_id_in_tile, tx += warpSize) { + // if (tx < width) + // partial_result = Reducer(partial_result, input[z][y][tx]); // } // } // } @@ -1531,9 +1623,21 @@ Status IrEmitterUnnested::EmitRowReduction( // the use of shfl_down is valid. const int64 width_in_tiles = RoundUpToNearest(CeilOfRatio(width, x_tile_size), kWarpSize); + Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( + reduce->shape().element_type(), + {depth / z_tile_size, height, width_in_tiles}, {2, 1, 0}); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + tiled_input_shape, ir_emitter_context_->device_description()); + llvm::Type* index_ty = GetIndexTypeForKernel( + reduce, + launch_dimensions.block_count() * launch_dimensions.threads_per_block(), + &ir_builder_); + + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) { - // Emit the loop body that reduces one z-x-tile. const int num_reduces = reducers.size(); llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType( input_shape.element_type(), ir_emitter_context_->llvm_module()); @@ -1542,8 +1646,9 @@ Status IrEmitterUnnested::EmitRowReduction( llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result." + llvm::Twine(i)); - TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, - init_value_gens[i](llvm_ir::IrArray::Index({}))); + TF_ASSIGN_OR_RETURN( + llvm::Value* const init_ir_value, + init_value_gens[i](llvm_ir::IrArray::Index(index_ty))); ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); @@ -1552,20 +1657,23 @@ Status IrEmitterUnnested::EmitRowReduction( llvm::Value* z_tile = tile_index[0]; llvm::Value* y = tile_index[1]; llvm::Value* x_tile = tile_index[2]; - llvm::Value* warp_id = ir_builder_.CreateUDiv( - x_tile, ir_builder_.getInt64(kWarpSize), "warp_id"); - llvm::Value* lane_id = ir_builder_.CreateURem( - x_tile, ir_builder_.getInt64(kWarpSize), "lane_id"); + + x_tile = ir_builder_.CreateZExtOrTrunc(x_tile, index_ty); + + llvm::Value* warp_id = + ir_builder_.CreateUDiv(x_tile, index_typed_const(kWarpSize), "warp_id"); + llvm::Value* lane_id = + ir_builder_.CreateURem(x_tile, index_typed_const(kWarpSize), "lane_id"); // The x-location of the last element in this z-x-tile. // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size); llvm::Value* last_x = ir_builder_.CreateNSWAdd( lane_id, ir_builder_.CreateNSWMul( - ir_builder_.getInt64(kWarpSize), + index_typed_const(kWarpSize), ir_builder_.CreateNSWAdd( - ir_builder_.getInt64(x_tile_size - 1), + index_typed_const(x_tile_size - 1), ir_builder_.CreateNSWMul( - warp_id, ir_builder_.getInt64(x_tile_size))))); + warp_id, index_typed_const(x_tile_size))))); KernelSupportLibrary ksl( &ir_builder_, @@ -1578,31 +1686,31 @@ Status IrEmitterUnnested::EmitRowReduction( int64 x_tile_loop_bound) -> Status { auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status { llvm::Value* z = ir_builder_.CreateNSWAdd( - z_indvar, ir_builder_.CreateNSWMul( - ir_builder_.getInt64(z_tile_size), z_tile)); - + z_indvar, + ir_builder_.CreateNSWMul(index_typed_const(z_tile_size), z_tile)); TF_RETURN_IF_ERROR(ksl.For( "x_tile", - /*start=*/0, /*end=*/x_tile_loop_bound, /*step=*/1, - [&](llvm::Value* x_indvar) -> Status { + /*start=*/index_typed_const(0), + /*end=*/index_typed_const(x_tile_loop_bound), + /*step=*/1, [&](llvm::Value* x_indvar) -> Status { // x = lane_id + // warpSize * (element_id_in_x_tile + warp_id * x_tile_size); llvm::Value* x = ir_builder_.CreateNSWAdd( lane_id, ir_builder_.CreateNSWMul( - ir_builder_.getInt64(kWarpSize), + index_typed_const(kWarpSize), ir_builder_.CreateNSWAdd( - x_indvar, - ir_builder_.CreateNSWMul( - warp_id, ir_builder_.getInt64(x_tile_size))))); + x_indvar, ir_builder_.CreateNSWMul( + warp_id, llvm::ConstantInt::get( + index_ty, x_tile_size))))); // Unless we know the x-tile is entirely in bounds, we have to // emit a x-in-bounds check before reading from the input. if (!x_tile_in_bounds) { llvm_ir::LlvmIfData if_x_in_bounds_data = - llvm_ir::EmitIfThenElse(ir_builder_.CreateICmpULT( - x, ir_builder_.getInt64(width)), - "x_in_bounds", &ir_builder_); + llvm_ir::EmitIfThenElse( + ir_builder_.CreateICmpULT(x, index_typed_const(width)), + "x_in_bounds", &ir_builder_); // Points ir_builder_ to the then-block. llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, &ir_builder_); @@ -1657,13 +1765,14 @@ Status IrEmitterUnnested::EmitRowReduction( }; return ksl.For("z_tile", - /*start=*/0, /*end=*/z_tile_size, /*step=*/1, - emit_z_tile_element_loop); + /*start=*/index_typed_const(0), + /*end=*/index_typed_const(z_tile_size), + /*step=*/1, emit_z_tile_element_loop); }; llvm::Value* tile_in_bounds = ir_builder_.CreateOr( ir_builder_.getInt1(width % (x_tile_size * kWarpSize) == 0), - ir_builder_.CreateICmpULT(last_x, ir_builder_.getInt64(width))); + ir_builder_.CreateICmpULT(last_x, index_typed_const(width))); TF_RETURN_IF_ERROR( ksl.If(tile_in_bounds, @@ -1717,7 +1826,7 @@ Status IrEmitterUnnested::EmitRowReduction( // lane 0 (which holds the partially accumulated result for its warp) to the // output element. llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpEQ(lane_id, ir_builder_.getInt64(0)), + ir_builder_.CreateICmpEQ(lane_id, index_typed_const(0)), "lane_id_is_zero", &ir_builder_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &ir_builder_); @@ -1731,26 +1840,23 @@ Status IrEmitterUnnested::EmitRowReduction( reduce_output_shapes[i]), &ir_builder_), &ir_builder_, "output_element_address"); - if (x_tile_size * z_tile_size < depth * width) { - TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, - partial_reduction_result_addresses[i])); - } else { + // We don't need to emit atomic operations if there is only one tile of + // results. 'depth' is the z dimension, 'width' is the x dimension. + if (z_tile_size >= depth && x_tile_size >= width) { TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {output_address, partial_reduction_result_addresses[i]}, output_address)); + } else { + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, + partial_reduction_result_addresses[i])); } } return Status::OK(); }; // Emit a parallel loop that iterates through every input tiles. - Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), - {depth / z_tile_size, height, width_in_tiles}, {2, 1, 0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); CHECK(LastThunk()->kind() == Thunk::Kind::kSequential); UpdateLaunchDimensions( launch_dimensions, @@ -1758,7 +1864,7 @@ Status IrEmitterUnnested::EmitRowReduction( ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, launch_dimensions, &ir_builder_) - .EmitLoop(IrName(reduce)); + .EmitLoop(IrName(reduce), index_ty); } // Figures out whether `reduce` is a row or column reduction, and which @@ -1871,9 +1977,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { // HandleReduce specializes reduction from a multi-dimensional array to a 1D // array. The specialized version requires an initializer thunk that // initializes the output array to the initial value of the reduce. - if (IsReductionToVector(*reduce) && - // NVPTX backend can't do atomic cmpxchg any narrower than 32 bits - 32 <= primitive_util::BitWidth(reduce->shape().element_type())) { + if (IsReductionToVector(*reduce)) { TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, BuildInitializerThunk(reduce)); std::vector> thunks; @@ -1958,6 +2062,14 @@ Status IrEmitterUnnested::HandleSelectAndScatter( "Dilation for SelectAndScatter not implemented on GPU."); } + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + source->shape(), ir_emitter_context_->device_description()); + llvm::Type* index_type = GetIndexTypeForKernel( + select_and_scatter, launch_dimensions.launch_bound(), &ir_builder_); + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_type, c); + }; + // kSelectAndScatter is implemented as two kernel launches: the first launch // initializes the output array to the given initial value, // and the second accumulates the "source" matrix to the @@ -1988,8 +2100,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( "selected_value_address", &ir_builder_); llvm::Value* selected_index_address = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - ir_builder_.getInt64Ty(), ir_builder_.getInt32(rank), - "selected_index_address", &ir_builder_); + index_type, index_typed_const(rank), "selected_index_address", + &ir_builder_); llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry( ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_); ir_builder_.CreateStore(ir_builder_.getInt1(false), @@ -1997,7 +2109,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // Create the inner loop to iterate over the window. llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), - &ir_builder_); + &ir_builder_, index_type); std::vector window_size; for (const auto& dim : window.dimensions()) { window_size.push_back(dim.size()); @@ -2011,17 +2123,17 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // Compute the operand index to visit and evaluate the condition whether the // operand index is within the bounds. The unsigned comparison includes // checking whether the operand index >= 0. - llvm_ir::IrArray::Index operand_index(source_index.size()); + llvm_ir::IrArray::Index operand_index(index_type, source_index.size()); llvm::Value* in_bounds_condition = ir_builder_.getInt1(true); for (int64 i = 0; i < rank; ++i) { llvm::Value* strided_index = ir_builder_.CreateNSWMul( - source_index[i], ir_builder_.getInt64(window.dimensions(i).stride())); + source_index[i], index_typed_const(window.dimensions(i).stride())); operand_index[i] = ir_builder_.CreateNSWSub( ir_builder_.CreateNSWAdd(strided_index, window_index[i]), - ir_builder_.getInt64(window.dimensions(i).padding_low())); + index_typed_const(window.dimensions(i).padding_low())); llvm::Value* index_condition = ir_builder_.CreateICmpULT( operand_index[i], - ir_builder_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); + index_typed_const(ShapeUtil::GetDimension(operand->shape(), i))); in_bounds_condition = ir_builder_.CreateAnd(in_bounds_condition, index_condition); } @@ -2093,7 +2205,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // value and the current output value. llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), &ir_builder_); - llvm_ir::IrArray::Index selected_index; + llvm_ir::IrArray::Index selected_index(operand_index.GetType()); for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP( selected_index_address, {ir_builder_.getInt32(i)}); @@ -2111,8 +2223,6 @@ Status IrEmitterUnnested::HandleSelectAndScatter( source_value_address); }; - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - source->shape(), ir_emitter_context_->device_description()); UpdateLaunchDimensions( launch_dimensions, // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk @@ -2123,7 +2233,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, source->shape(), launch_dimensions, &ir_builder_) - .EmitLoop(IrName(select_and_scatter)); + .EmitLoop(IrName(select_and_scatter), index_type); } Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) { @@ -2205,7 +2315,7 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { return Status::OK(); } -Status IrEmitterUnnested::HandleGenerateToken(HloInstruction* gen_token) { +Status IrEmitterUnnested::HandleAfterAll(HloInstruction* gen_token) { return Status::OK(); } @@ -2332,11 +2442,6 @@ GetHloBufferSlices(const HloInstruction* hlo, return slices; } -Status IrEmitterUnnested::HandleGather(HloInstruction* gather) { - // TODO(b/72710576): Gather is not implemented on GPUs - return Unimplemented("Gather is not implemented on GPUs."); -} - std::unique_ptr IrEmitterUnnested::BuildKernelThunk( const HloInstruction* inst, int unroll_factor) { const BufferAssignment& buffer_assn = @@ -2462,17 +2567,14 @@ std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( const HloInstruction* inst) { CHECK_EQ(HloOpcode::kInfeed, inst->opcode()); - std::vector tuple_element_buffers; - for (int64 i = 0; i < inst->shape().tuple_shapes_size(); ++i) { - BufferAllocation::Slice buffer = ir_emitter_context_->buffer_assignment() - .GetUniqueSlice(inst, {i}) - .ConsumeValueOrDie(); - tuple_element_buffers.push_back(buffer); - } - - return MakeUnique( - tuple_element_buffers, - /*destination_buffer=*/GetAllocationSlice(*inst), inst); + ShapeTree slices(inst->shape()); + slices.ForEachMutableElement( + [this, inst](const ShapeIndex& index, BufferAllocation::Slice* slice) { + *slice = ir_emitter_context_->buffer_assignment() + .GetUniqueSlice(inst, index) + .ConsumeValueOrDie(); + }); + return MakeUnique(slices, inst); } namespace { @@ -2609,14 +2711,15 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by // repeating the literal 4 or 2 times, so long as the destination buffer is // an even multiple of 32 bits long. + const Shape& output_shape = ShapeUtil::GetSubshape(hlo->shape(), index); if ((num_bytes == 1 || num_bytes == 2) && - ShapeUtil::ByteSizeOf(hlo->shape()) % 4 == 0) { + ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) { uint16 pattern16; if (num_bytes == 1) { uint8 b = literal_bytes.front(); pattern16 = uint16{b} | (uint16{b} << 8); } else { - pattern16 = literal_bytes.front(); + memcpy(&pattern16, literal_bytes.data(), sizeof(pattern16)); } uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); return {MakeUnique( @@ -2838,7 +2941,9 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( if (!hlo.IsMultiOutputFusion()) { return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo), launch_dimensions, &ir_builder_, unroll_factor) - .EmitLoop(IrName(&hlo)); + .EmitLoop(IrName(&hlo), + GetIndexTypeForKernel(&hlo, launch_dimensions.launch_bound(), + &ir_builder_)); } // For multiple outputs fusion, we need to emit each operand and the root. @@ -2846,10 +2951,12 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) { output_arrays.push_back(GetIrArray(hlo, hlo, {i})); } - TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays, - launch_dimensions, &ir_builder_, - unroll_factor) - .EmitLoop(IrName(&hlo))); + TF_RETURN_IF_ERROR( + ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions, + &ir_builder_, unroll_factor) + .EmitLoop(IrName(&hlo), + GetIndexTypeForKernel( + &hlo, launch_dimensions.launch_bound(), &ir_builder_))); std::vector tuple_operand_ptrs; for (int64 i = 0; i < output_arrays.size(); ++i) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index d228be81d47906850fa98e22a1d974500a7d34ed..819060061a9b8bcf0db4f782852b0a7c6530143c 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -67,7 +67,6 @@ class IrEmitterUnnested : public IrEmitter { Status HandleDot(HloInstruction* dot) override; Status HandleFft(HloInstruction* fft) override; Status HandleFusion(HloInstruction* fusion) override; - Status HandleGather(HloInstruction* gather) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleReduce(HloInstruction* reduce) override; Status HandleSelectAndScatter(HloInstruction* instruction) override; @@ -77,7 +76,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleRng(HloInstruction* random) override; Status HandleSelect(HloInstruction* select) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; - Status HandleGenerateToken(HloInstruction* gen_token) override; + Status HandleAfterAll(HloInstruction* gen_token) override; Status EmitTargetElementLoop( const HloInstruction& hlo, diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index d541776f00ca9c0986fecd272930e5585852f6f3..ea661b3c2cb2c945297ac2098cd1c4009b2e966d 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -23,9 +23,11 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -69,6 +71,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, // In that case, the operand of the reduce needs to have the same shape // as the other tuple operands, but also we need to compare the output // shapes of the reduces. + // TODO(tjoerg): Allow differences in fp precision. auto* element_instr_1 = get_element_instr(instr1); auto* element_instr_2 = get_element_instr(instr2); if (element_instr_1->opcode() == HloOpcode::kReduce && @@ -82,31 +85,35 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, } namespace { -bool IsReduction(HloInstruction* instr) { +bool IsInputFusibleReduction(HloInstruction* instr) { if (instr->IsMultiOutputFusion()) { for (const HloInstruction* operand : instr->fused_expression_root()->operands()) { if (operand->opcode() == HloOpcode::kReduce) { + CHECK(instr->fusion_kind() == HloInstruction::FusionKind::kInput) + << " Reduce multi-output fusion " << instr->ToString() + << " must be an input fusion."; return true; } } return false; } else if (instr->opcode() == HloOpcode::kFusion) { - return instr->fused_expression_root()->opcode() == HloOpcode::kReduce; + // The loop emitter can handle to-vector reduce fusions. Such reduce + // fusions have the fusion kind kLoop rather than kInput. We do not fuse + // to-vector reduce fusions, because the resulting fusions may no longer be + // supported by loop emitter. + return IsReductionToVector(*instr->fused_expression_root()); } else { - return instr->opcode() == HloOpcode::kReduce; + return IsReductionToVector(*instr); } } } // namespace bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { // We can fuse reduces and loop fusions. - return IsReduction(instr) || + return IsInputFusibleReduction(instr) || (instr->opcode() == HloOpcode::kFusion && - instr->fusion_kind() == HloInstruction::FusionKind::kLoop && - // TODO(b/110202584): bitcasts make nested fusions, GPU has no support - // for nested fusions. - instr->fused_expression_root()->opcode() != HloOpcode::kBitcast); + instr->fusion_kind() == HloInstruction::FusionKind::kLoop); } int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, @@ -147,5 +154,110 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1, return instr1->fusion_kind() != HloInstruction::FusionKind::kLoop; } +bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { + bool changed = false; + RecomputeReachability(); + + tensorflow::gtl::FlatSet to_fuse; + // Keep a list of the instructions to fuse after making all the fusion + // decisions. We first aggressively add instructions to potential_fusion_list, + // then filter out instructions that will be no longer fusable because of + // reachability change. This avoids recalculating reachability on a large set + // of instructions. + std::vector> + potential_fusion_list; + std::vector> fusion_list; + std::vector instrs_to_update_reachability; + + // For each reduce or reduce multi-output fusion, try to fuse it with loop + // fusions operands. + for (HloInstruction* consumer : computation()->MakeInstructionPostOrder()) { + if (consumer->user_count() == 0) { + continue; + } + if (!IsInputFusibleReduction(consumer)) { + continue; + } + + auto consumer_operands = consumer->operands(); + for (size_t i = 0; i < consumer_operands.size(); ++i) { + HloInstruction* producer = consumer_operands[i]; + if (!producer->IsFusable()) { + continue; + } + const bool is_loop_fusion = + producer->opcode() == HloOpcode::kFusion && + producer->fusion_kind() == HloInstruction::FusionKind::kLoop; + if (!is_loop_fusion) { + continue; + } + if (!ShapesCompatibleForFusion(producer, consumer)) { + continue; + } + // If we have already decided to fuse this producer, skip it. + if (ContainsKey(to_fuse, producer)) { + continue; + } + // Do not fuse a producer if the other operands of the fusion are + // reachable from the producer, this would create a cycle. + if (c_any_of(consumer_operands, [&](HloInstruction* operand) { + return producer != operand && + reachability()->IsReachable(producer, operand); + })) { + break; + } + to_fuse.insert(producer); + potential_fusion_list.emplace_back(producer, consumer); + instrs_to_update_reachability.push_back(producer); + instrs_to_update_reachability.push_back(consumer); + break; + } + } + + // Filter out pairs that will be no longer fusable because of reachability + // change. + for (auto& fusion_pair : potential_fusion_list) { + HloInstruction* producer = fusion_pair.first; + HloInstruction* consumer = fusion_pair.second; + if (!c_any_of(consumer->operands(), [&](HloInstruction* operand) { + return producer != operand && + reachability()->IsReachable(producer, operand); + })) { + UpdateReachability(producer, consumer, instrs_to_update_reachability); + fusion_list.push_back(fusion_pair); + } + } + + for (auto fusions_to_create : fusion_list) { + HloInstruction* producer = fusions_to_create.first; + HloInstruction* consumer = fusions_to_create.second; + if (consumer->opcode() != HloOpcode::kFusion) { + // Fusing with a reduce (fusion) always results in an input fusion. + HloInstruction* input_fusion = + computation()->AddInstruction(HloInstruction::CreateFusion( + consumer->shape(), HloInstruction::FusionKind::kInput, consumer)); + VLOG(2) << "Fuse producer " << producer->name() << " and its consumer " + << consumer->name() << " into " << input_fusion->name(); + TF_CHECK_OK(computation()->ReplaceInstruction(consumer, input_fusion)); + if (producer->opcode() == HloOpcode::kFusion) { + input_fusion->MergeFusionInstructionIntoMultiOutput(producer); + } else { + input_fusion->FuseInstructionIntoMultiOutput(producer); + } + } else { + VLOG(2) << "Fuse producer " << producer->name() << " into its consumer " + << consumer->name(); + + if (producer->opcode() == HloOpcode::kFusion) { + consumer->MergeFusionInstructionIntoMultiOutput(producer); + } else { + consumer->FuseInstructionIntoMultiOutput(producer); + } + } + changed = true; + } + return changed; +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h index 16db0e0f02d5cbf582f0e4236297b3d5407014b3..67ca5d49eee8508e93284b134f8410eb3a89f9ce 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h @@ -45,6 +45,9 @@ class GpuMultiOutputFusion : public MultiOutputFusion { // Test if it's legal to fuse instr1 and instr2 into one fusion instruction. bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2) override; + + // Fuse loop fusions into reduce fusions. + bool DoProducerConsumerMultiOutputFusion() override; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index 5e7ceb7976b5d1957f706c12ec255e93991344b8..979ea79243818c398b1b130254a41c95ced51830 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -255,5 +255,99 @@ TEST_F(InstructionFusionTest, MultiOutputFusionTwoLoops) { op::Tuple(op::Multiply(), op::Divide())); } +TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduce) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_add { + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + p1.1 = f32[2,2,2]{2,1,0} parameter(1) + ROOT add = f32[2,2,2]{2,1,0} add(p0.1, p1.1) + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + c0 = f32[] constant(0) + add = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_add + reduce = f32[2,2]{1,0} reduce(add, c0), dimensions={2}, to_apply=scalar_add_computation + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, add) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement())); + const HloInstruction* fusion = root->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Add())); +} + +TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_select { + p1.1 = f32[2,2,2]{2,1,0} parameter(1) + c0 = f32[] constant(0) + broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={} + greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast) + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast) + } + + fused_reduce { + p0.2 = f32[2,2,2]{2,1,0} parameter(0) + c1 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0.2, c1), dimensions={2}, to_apply=scalar_add_computation + mul = f32[2,2,2]{2,1,0} multiply(p0.2, p0.2) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=scalar_add_computation + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + select = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select + fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, calls=fused_reduce + gte0 = f32[2,2]{1,0} get-tuple-element(fusion), index=0 + gte1 = f32[2,2]{1,0} get-tuple-element(fusion), index=1 + ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(gte1, gte1, select) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement(), + op::GetTupleElement())); + const HloInstruction* fusion = root->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce(), op::Select())); +} + +TEST_F(InstructionFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_element_wise { + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + p1.1 = f32[2,2,2]{2,1,0} parameter(1) + ROOT root = f32[2,2,2]{2,1,0} add(p0.1, p1.1) + } + + fused_reduce { + p0.2 = f32[2,2,2]{2,1,0} parameter(0) + mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2, f32[2,2,2]{2,1,0} p0.2) + c1 = f32[] constant(0) + ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} mul, f32[] c1), dimensions={1}, to_apply=scalar_add_computation + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + element_wise = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_element_wise + fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(element_wise), kind=kLoop, calls=fused_reduce + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(fusion, element_wise) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index d8c07dc3119fb81a3ef22822acb11b7c4d5bbca5..cd833ec7bd858aabee84ac306d198e80eb112506 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -58,7 +58,7 @@ ParallelLoopEmitter::ParallelLoopEmitter( std::vector ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name) { + tensorflow::StringPiece loop_name, llvm::Type* index_type) { // Emit the following code in LLVM IR: // linear_index = blockIdx.x * blockDim.x + threadIdx.x; // if (linear_index < num_elements) { @@ -71,14 +71,13 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( // // %nctaid.x is currently specified as 2147483647. VLOG(3) << "EmitIndexAndSetExitBasicBlock unroll_factor " << unroll_factor_; + CHECK_NE(index_type, nullptr); std::vector array_indices; - llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, ir_builder_); llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_count(), static_cast(block_id)); - block_id = - ir_builder_->CreateZExt(block_id, ir_builder_->getInt64Ty(), "block_id"); + block_id = ir_builder_->CreateZExtOrTrunc(block_id, index_type, "block_id"); // Per the PTX documentation: // "It is guaranteed that [...] 0 <= %tid.x < %ntid.x" @@ -88,13 +87,15 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, ir_builder_); llvm_ir::AddRangeMetadata(0, launch_dimensions_.threads_per_block(), static_cast(thread_id)); - thread_id = ir_builder_->CreateZExt(thread_id, ir_builder_->getInt64Ty(), - "thread_id"); + thread_id = + ir_builder_->CreateZExtOrTrunc(thread_id, index_type, "thread_id"); llvm::Value* linear_index_base = ir_builder_->CreateAdd( ir_builder_->CreateMul( block_id, - ir_builder_->getInt64(launch_dimensions_.threads_per_block()), "", + llvm::ConstantInt::get(index_type, + launch_dimensions_.threads_per_block()), + "", /*HasNUW=*/true, /*HasNSW=*/true), thread_id, "linear_index", /*HasNUW=*/true, /*HasNSW=*/true); @@ -110,21 +111,23 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( llvm::Intrinsic::assume, {ir_builder_->CreateICmpULT( linear_index_base, - ir_builder_->getInt64(launch_dimensions_.threads_per_block() * - launch_dimensions_.block_count()), + llvm::ConstantInt::get(index_type, + launch_dimensions_.threads_per_block() * + launch_dimensions_.block_count()), "linear_index_in_range")}, {}, ir_builder_); if (unroll_factor_ > 1) { linear_index_base = ir_builder_->CreateMul( - linear_index_base, ir_builder_->getInt64(unroll_factor_), + linear_index_base, llvm::ConstantInt::get(index_type, unroll_factor_), "linear_index_base", /*HasNUW=*/true, /*HasNSW=*/true); } array_indices.emplace_back(linear_index_base, shape_, ir_builder_); for (int i = 1; i < unroll_factor_; ++i) { llvm::Value* linear_index = ir_builder_->CreateAdd( - linear_index_base, ir_builder_->getInt64(i), "linear_index", + linear_index_base, llvm::ConstantInt::get(index_type, i), + "linear_index", /*HasNUW=*/true, /*HasNSW=*/true); array_indices.emplace_back(linear_index, shape_, ir_builder_); } @@ -132,7 +135,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( auto if_in_bounds = llvm_ir::EmitIfThenElse( ir_builder_->CreateICmpULT( linear_index_base, - ir_builder_->getInt64(ShapeUtil::ElementsIn(shape_))), + llvm::ConstantInt::get(index_type, ShapeUtil::ElementsIn(shape_))), llvm_ir::IrName(loop_name, "in_bounds"), ir_builder_, false); // Set exit_bb_ to the exit block of the if structure. diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index 25318b3bed8bf4a2dfe3a4a974269d0405c3bfec..302e1bf1bc8e90f2eebd838f156a1552e86185ac 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -58,7 +58,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ~ParallelLoopEmitter() override = default; std::vector EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name) override; + tensorflow::StringPiece loop_name, llvm::Type* index_type) override; private: // The thread and block dimension to parallelize the loop on. diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h index c125474edb1036090a926020f2b1e7fcf64c751a..02471129e004b4876ce20a62cade34060c65b478 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h @@ -47,6 +47,7 @@ class LaunchDimensions { int64 block_count() const { return block_count_; } int64 threads_per_block() const { return threads_per_block_; } + int64 launch_bound() const { return block_count() * threads_per_block(); } private: int64 block_count_; diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h index 8218f4fd11d3978d0ecc53fc15e287aea4b69ec3..39a6a38d001f502b2abb8de6efe2ce623b478c71 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index a04aa4069d2344ca7b2e763cfeeb53abcbefc21d..4005fc0d114a3ec7a38dfb5edecdaeb1e8497ade 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -230,6 +230,9 @@ Status HeapSimulator::RunComputation( // // INVARIANT: Either Alloc or ShareBuffer will be called for each buffer // that we should assign. + + // Make sure each buffer get reused at most once. + FlatSet reused_buffers; for (const BufferValue* buffer : buffers_defined_by_instruction) { if (IgnoreBuffer(buffer)) { continue; @@ -242,6 +245,9 @@ Status HeapSimulator::RunComputation( bool shared = false; if (options_.may_reuse_operand_buffers) { for (const BufferValue* operand_buffer : operand_buffers_to_free) { + if (reused_buffers.count(operand_buffer) != 0) { + continue; + } if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) && buffer->instruction()->opcode() != HloOpcode::kCopy && points_to_analysis.CanShareOperandBufferWithUser( @@ -251,6 +257,7 @@ Status HeapSimulator::RunComputation( << operand_buffer->ToString(); ShareBuffer(buffer, operand_buffer, instruction); shared = true; + reused_buffers.insert(operand_buffer); break; } } diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 93d7a141258a3186b10cf2728b70a034488a84f2..3849b565e3136924b2d2b1929353885f85b1a043 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -198,6 +198,11 @@ class HeapSimulatorTracker { .ConsumeValueOrDie(); } + int64 OffsetAt(const HloInstruction* instruction, const ShapeIndex& index) { + const BufferValue* buffer = BufferAt(instruction, index); + return result_.chunk_map.at(buffer).offset; + } + // Ensures the expected sequence of Alloc/Free/Finish calls was performed. void ExpectCallSequence(const CallSequence& expected) const { EXPECT_EQ(expected, actual_calls_); @@ -209,10 +214,9 @@ class HeapSimulatorTracker { const ShapeIndex& index_a, const HloInstruction* instruction_b, const ShapeIndex& index_b) { - const BufferValue* a = BufferAt(instruction_a, index_a); - const BufferValue* b = BufferAt(instruction_b, index_b); - EXPECT_EQ(result_.chunk_map[a].offset, result_.chunk_map[b].offset) - << *a << ", " << *b; + int64 offset_a = OffsetAt(instruction_a, index_a); + int64 offset_b = OffsetAt(instruction_b, index_b); + EXPECT_EQ(offset_a, offset_b); } private: @@ -311,6 +315,43 @@ TEST_F(HeapSimulatorTest, MultiplyAdd) { tracker.ExpectSharedBuffers(add, {}, mul, {}); } +TEST_F(HeapSimulatorTest, BufferReusedOnce) { + HeapSimulatorTracker tracker(TestName()); + auto builder = HloComputation::Builder(TestName()); + + HloComputation::Builder fusion_builder("fusion"); + { + HloComputation::Builder& builder = fusion_builder; + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, f32vec4_, "A")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec4_, HloOpcode::kExp, a_param)); + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param)); + + builder.AddInstruction(HloInstruction::CreateTuple({exp, neg})); + } + auto fusion_computation = + tracker.module()->AddEmbeddedComputation(fusion_builder.Build()); + auto a_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec4_, "paramA")); + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param)); + auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( + ShapeUtil::MakeTupleShape({f32vec4_, f32vec4_}), + HloInstruction::FusionKind::kLoop, {neg}, fusion_computation)); + tracker.module()->AddEntryComputation(builder.Build()); + + tracker.RunWholeModule({a_param, neg, fusion}); + + auto neg_buffer = tracker.OffsetAt(neg, {}); + int64 output_buffer_0 = tracker.OffsetAt(fusion, {0}); + int64 output_buffer_1 = tracker.OffsetAt(fusion, {1}); + // Only one buffer should be shared. + EXPECT_TRUE((neg_buffer == output_buffer_0) ^ + (neg_buffer == output_buffer_1)); +} + TEST_F(HeapSimulatorTest, MultiplyDot) { auto builder = HloComputation::Builder(TestName()); auto paramA = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index e201359d3d25b7d2dda852762c6de1fcb75685d7..d2417910606fdd13223076d33ff1bda1dd291d98 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -145,12 +145,16 @@ message HloInstructionProto { repeated int64 operand_ids = 36; repeated int64 control_predecessor_ids = 37; repeated int64 called_computation_ids = 38; - repeated int64 replica_group_ids = 44; xla.OpSharding sharding = 40; // Backend configuration for the instruction. Has backend-specific meaning. string backend_config = 43; + + // Cross Replica Sum fields. + repeated int64 replica_group_ids = 44; + int64 all_reduce_id = 45; + string cross_replica_sum_barrier = 46; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 0a948cc390fed7daed3e0cc938bf59cbcfd9b4df..e8a4b034b4396860bd5873f43003844ce92dea6c 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -452,15 +452,16 @@ string HloAliasAnalysis::ToString() const { /* static */ StatusOr> HloAliasAnalysis::Run( - HloModule* module) { + HloModule* module, const HloDataflowAnalysis::FusionCanShareBufferFunction& + fusion_can_share_buffer) { VLOG(2) << "HloAliasAnalysis::Run on module " << module->name(); XLA_VLOG_LINES(2, module->ToString()); auto alias_analysis = WrapUnique(new HloAliasAnalysis(module)); - TF_ASSIGN_OR_RETURN( - alias_analysis->dataflow_analysis_, - HloDataflowAnalysis::Run(*module, /*ssa_form=*/true, - /*bitcast_defines_value=*/false)); + TF_ASSIGN_OR_RETURN(alias_analysis->dataflow_analysis_, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true, + /*bitcast_defines_value=*/false, + fusion_can_share_buffer)); BufferValueMap buffer_map(alias_analysis->dataflow_analysis()); buffer_map.MergeAliasedBuffers(); diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h index 67dfd4301b3a027a496911ecf6f06841dfd6423a..afb0c20f0cdf3eb92f72ab8bc368b4b8d723459e 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h @@ -39,7 +39,10 @@ class HloAliasAnalysis { public: // The callgraph of the given HloModule must be flattened // (xla::FlattenCallGraph) prior to running the analysis. - static StatusOr> Run(HloModule* module); + static StatusOr> Run( + HloModule* module, + const HloDataflowAnalysis::FusionCanShareBufferFunction& + fusion_can_share_buffer = nullptr); string ToString() const; diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index ef8bb030fbc7a99e1fc907c0b1c1e9b0a16ecbd1..34b18b0e21fbf6ce5d406cae9dbd64b9744f5a83 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -120,6 +120,30 @@ HloInstruction* HloComputation::AddParameter( return instructions_.back().get(); } +namespace { + +// Returns the new name for a fusion parameter when we change its number. +// +// Fusion parameters are named foo.param_1, bar.param_2, etc. We are +// renumbering the parameters, so replace the final number in the name with +// the updated value. +string RenameFusionParameter(const string& original_name, int64 new_param_no) { + const string param_underscore = ".param_"; + size_t index = original_name.rfind(param_underscore); + if (index == string::npos) { + return original_name; + } + string after_param = original_name.substr(index + param_underscore.size()); + int64 numeric_suffix; + if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) { + return StrCat(original_name.substr(0, index + param_underscore.size()), + new_param_no); + } + return original_name; +} + +} // namespace + Status HloComputation::RemoveParameter(int64 param_no) { CHECK_GE(param_no, 0); CHECK_LT(param_no, param_instructions_.size()); @@ -132,21 +156,8 @@ Status HloComputation::RemoveParameter(int64 param_no) { while (param_no < param_instructions_.size()) { param_instruction = param_instructions_[param_no]; - string param_name = param_instruction->name(); - // Fusion parameters are named foo.param_1, bar.param_2, etc. We are - // renumbering the parameters, so replace the final number in the name with - // the updated value. - const string param_underscore = ".param_"; - size_t index = param_name.rfind(param_underscore); - if (index == string::npos) { - string after_param = name().substr(index + param_underscore.size()); - int64 numeric_suffix; - if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) { - param_name = - StrCat(param_name.substr(0, index), param_underscore, param_no); - } - } - + string param_name = + RenameFusionParameter(param_instruction->name(), param_no); HloInstruction* new_instr = AddInstructionInternal(HloInstruction::CreateParameter( param_no, param_instruction->shape(), param_name)); @@ -159,6 +170,34 @@ Status HloComputation::RemoveParameter(int64 param_no) { return Status::OK(); } +Status HloComputation::RemoveUnusedParameters() { + CHECK(IsFusionComputation()); + int64 removed = 0; + for (int64 i = 0; i < param_instructions_.size(); ++i) { + HloInstruction* param_instruction = param_instructions_[i]; + if (param_instruction->user_count() == 0 && + param_instruction != root_instruction()) { + TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); + ++removed; + continue; + } + + if (removed > 0) { + const int64 param_no = i - removed; + string param_name = + RenameFusionParameter(param_instruction->name(), param_no); + HloInstruction* new_instr = + AddInstructionInternal(HloInstruction::CreateParameter( + param_no, param_instruction->shape(), param_name)); + TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); + param_instructions_[param_no] = new_instr; + TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); + } + } + param_instructions_.resize(param_instructions_.size() - removed); + return Status::OK(); +} + bool HloComputation::IsRemovable(const HloInstruction* instruction) { // If the instruction has control predecessors or successors then we cannot // remove the instruction without violating ordering constraints (added, for @@ -263,46 +302,11 @@ void HloComputation::set_root_instruction( namespace { -// Helper class which computes the post order of an expression rooted at a -// particular instruction. -class InstructionPostOrderer : public DfsHloVisitorWithDefault { - public: - // added_instructions is the set of instructions which have already been - // accounted for in the post order in previous invocations of - // GetOrder. Without this mechanism, instructions which are predecessors of - // multiple root instructions of the computation can be added to the post - // order more than once. - static std::list GetOrder( - HloInstruction* root, - tensorflow::gtl::FlatSet* added_instructions) { - InstructionPostOrderer orderer(added_instructions); - TF_CHECK_OK(root->Accept(&orderer)); - return std::move(orderer.post_order_); - } - - private: - explicit InstructionPostOrderer( - tensorflow::gtl::FlatSet* added_instructions) - : added_instructions_(added_instructions) {} - ~InstructionPostOrderer() override {} - - Status DefaultAction(HloInstruction* hlo_instruction) override { - if (added_instructions_->count(hlo_instruction) == 0) { - post_order_.push_back(hlo_instruction); - added_instructions_->insert(hlo_instruction); - } - return Status::OK(); - } - - std::list post_order_; - tensorflow::gtl::FlatSet* added_instructions_; -}; - // Helper which builds a post order of the HLO call graph. void ComputeComputationPostOrder( HloComputation* computation, tensorflow::gtl::FlatSet* visited, - std::list* post_order) { + std::vector* post_order) { if (visited->insert(computation).second) { for (auto* instruction : computation->instructions()) { for (HloComputation* called_computation : @@ -314,48 +318,53 @@ void ComputeComputationPostOrder( } } -std::list ComputeInstructionPostOrder( - HloInstruction* root, tensorflow::gtl::FlatSet* visited) { - std::list post_order; - std::vector> dfs_stack; - dfs_stack.emplace_back(root, false); +enum State { kVisiting, kVisited }; + +void ComputeInstructionPostOrder( + std::vector* post_order, HloInstruction* root, + tensorflow::gtl::FlatMap* visited) { + std::vector dfs_stack; + dfs_stack.push_back(root); while (!dfs_stack.empty()) { const auto current = dfs_stack.back(); - if (current.second) { - dfs_stack.pop_back(); - if (!visited->insert(current.first).second) { - continue; - } - post_order.push_back(current.first); - } else { - if (visited->count(current.first)) { + auto it = visited->find(current); + if (it != visited->end()) { + if (it->second == kVisited) { + // Already visited. dfs_stack.pop_back(); continue; } - dfs_stack.back().second = true; - - // Add the operands to the stack in reverse order so the first operand is - // processed first. This will produce a more natural ordering and a nicer - // result for thigns like HLO stringification. - const auto& operands = current.first->operands(); - for (int64 i = operands.size() - 1; i >= 0; --i) { - dfs_stack.emplace_back(operands[i], false); - } + // Visit this node. + CHECK_EQ(kVisiting, it->second); + dfs_stack.pop_back(); + post_order->push_back(current); + it->second = kVisited; + continue; + } - for (HloInstruction* op : current.first->control_predecessors()) { - dfs_stack.emplace_back(op, false); - } + visited->insert({current, kVisiting}); + + // Add the operands to the stack in reverse order so the first operand is + // processed first. This will produce a more natural ordering and a nicer + // result for thigns like HLO stringification. + const auto& operands = current->operands(); + for (int64 i = operands.size() - 1; i >= 0; --i) { + dfs_stack.emplace_back(operands[i]); + } + + for (HloInstruction* op : current->control_predecessors()) { + dfs_stack.emplace_back(op); } } - return post_order; } } // namespace -std::list HloComputation::MakeInstructionPostOrder() const { - std::list post_order; - std::list trace_instructions; - tensorflow::gtl::FlatSet added_instructions; +std::vector HloComputation::MakeInstructionPostOrder() const { + std::vector post_order; + post_order.reserve(instruction_count()); + std::vector trace_instructions; + tensorflow::gtl::FlatMap visited; for (auto& instruction : instructions_) { if (instruction->opcode() == HloOpcode::kTrace) { // Trace instructions aren't handled by the DFS visitor. Add trace @@ -363,21 +372,20 @@ std::list HloComputation::MakeInstructionPostOrder() const { // users). trace_instructions.push_back(instruction.get()); } else if (instruction->users().empty()) { - post_order.splice( - post_order.end(), - ComputeInstructionPostOrder(instruction.get(), &added_instructions)); + ComputeInstructionPostOrder(&post_order, instruction.get(), &visited); } } - post_order.splice(post_order.end(), trace_instructions); + post_order.insert(post_order.end(), trace_instructions.begin(), + trace_instructions.end()); CHECK_EQ(instructions_.size(), post_order.size()) << "number of instructions does not match post order size"; return post_order; } -std::list HloComputation::MakeEmbeddedComputationsList() +std::vector HloComputation::MakeEmbeddedComputationsList() const { tensorflow::gtl::FlatSet visited; - std::list post_order; + std::vector post_order; // To avoid special handling of this computation, cast away const of // 'this'. 'this' is immediately removed from the post order after @@ -648,7 +656,7 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, std::unique_ptr HloComputation::ComputeReachability() const { - const std::list all = MakeInstructionPostOrder(); + const auto& all = MakeInstructionPostOrder(); auto result = MakeUnique(all); std::vector inputs; diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 0da4a305f3d5d694a1918fed294337100b0a27fd..c1c3e79ebc789eff0873515c5fffd11089b92043 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -113,6 +113,11 @@ class HloComputation { // instruction. Status RemoveParameter(int64 param_no); + // Remove unused parameters from the computation. + // Note this is only applicatable to the computation for the fusion + // instruction. + Status RemoveUnusedParameters(); + // Add new parameter instruction to the computation. // This should be a new parameter. Instruction will be appended to parameters // and inserted to the instruction list. @@ -199,7 +204,7 @@ class HloComputation { // Compute and return a post-order of the instructions in the computation. In // this order, definitions of values always appear before their uses. - std::list MakeInstructionPostOrder() const; + std::vector MakeInstructionPostOrder() const; // Computes and returns the reachability between HLO instructions in the // computation. The returned HloReachabilityMap is constructed such that @@ -221,7 +226,7 @@ class HloComputation { // transitively. The embedded computations are sorted such that if computation // A calls computation B (eg, via a map instruction) then A will appear after // B in the list. - std::list MakeEmbeddedComputationsList() const; + std::vector MakeEmbeddedComputationsList() const; // Creates a fusion instruction containing the given instructions. // `fusion_kind` indicates the type of the fusion, e.g., loop fusion or fusion diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 3f59d31bb9123a480864ddfca939ec3c032298c9..a8f3f0e9c2dca8fb97ebc8f8c9dd80fcf7f4de4a 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -375,20 +375,20 @@ TEST_F(HloComputationTest, DeepCopyToken) { // Test that DeepCopyInstruction properly handles tokens which should not be // copied. auto builder = HloComputation::Builder(TestName()); - auto token = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(token).ValueOrDie(); // No copy should be added. - EXPECT_THAT(copy, op::GenerateToken()); + EXPECT_THAT(copy, op::AfterAll()); } TEST_F(HloComputationTest, DeepCopyTokenTuple) { // Test that DeepCopyInstruction properly handles tokens which should not be // copied. auto builder = HloComputation::Builder(TestName()); - auto token = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(42.0))); auto tuple = @@ -417,6 +417,9 @@ TEST_F(HloComputationTest, CycleDetection) { // Add a control dependency to create a cycle. ASSERT_IS_OK(add->AddControlDependencyTo(negate)); + auto instructions = computation->MakeInstructionPostOrder(); + EXPECT_EQ(3, instructions.size()); + const auto visitor = [](HloInstruction* instruction) { return Status::OK(); }; auto visit_status = computation->Accept(visitor); ASSERT_FALSE(visit_status.ok()); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 762e1afc71b108b2e32b5a7f7f1bbeb783fc6fbd..8955e26d5cd1bf30f965395750f5078d070a6906 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -393,7 +393,7 @@ Status HloCostAnalysis::HandleTranspose(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleGenerateToken(const HloInstruction*) { +Status HloCostAnalysis::HandleAfterAll(const HloInstruction*) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 0d66736fe1d0677d13a63ede7a203d6ac20c76f5..44e5df587c4bf0b3004c8d624c45d42d258c3661 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -97,7 +97,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleBroadcast(const HloInstruction* broadcast) override; Status HandlePad(const HloInstruction* pad) override; Status HandleReshape(const HloInstruction* reshape) override; - Status HandleGenerateToken(const HloInstruction* token) override; + Status HandleAfterAll(const HloInstruction* token) override; Status HandleTranspose(const HloInstruction* transpose) override; Status HandleWhile(const HloInstruction* xla_while) override; Status HandleConditional(const HloInstruction* conditional) override; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index d22bef56730da194816b4ee89dc3196439b350f9..9fc4c48226fa5307f5e030a612f3957756827e37 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -59,9 +59,9 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a unary user function: x => exp(x + 0.5) { XlaBuilder builder("add_and_exp"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto half = builder.ConstantR0(0.5); - builder.Exp(builder.Add(x, half)); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto half = ConstantR0(&builder, 0.5); + Exp(Add(x, half)); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); add_and_exp_ = computation_status.ConsumeValueOrDie(); @@ -70,9 +70,9 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary user function: (x, y) => x + y { XlaBuilder builder("add"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Add(x, y); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); add_ = computation_status.ConsumeValueOrDie(); @@ -81,9 +81,9 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a sigmoid function: x => 1 / (1 + exp(-x)) { XlaBuilder builder("sigmoid"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto one = builder.ConstantR0(1.0); - builder.Div(one, builder.Add(one, builder.Exp(builder.Neg(x)))); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto one = ConstantR0(&builder, 1.0); + Div(one, Add(one, Exp(Neg(x)))); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); sigmoid_ = computation_status.ConsumeValueOrDie(); @@ -92,9 +92,9 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary max function: (x, y) => max (x, y) { XlaBuilder builder("max"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - builder.Max(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Max(x, y); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); max_ = computation_status.ConsumeValueOrDie(); @@ -103,9 +103,9 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary GT function: (x, y) => x > y { XlaBuilder builder("gt"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - builder.Gt(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Gt(x, y); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); gt_ = computation_status.ConsumeValueOrDie(); @@ -137,9 +137,9 @@ class HloCostAnalysisTest : public ::testing::Test { TEST_F(HloCostAnalysisTest, MatrixMultiply) { XlaBuilder builder("matrix_multiply"); - auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs"); - auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs"); - auto result = builder.Dot(lhs, rhs); + auto lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs"); + auto rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs"); + Dot(lhs, rhs); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -159,8 +159,8 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) { TEST_F(HloCostAnalysisTest, Map) { XlaBuilder builder("map"); - auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in"); - auto result = builder.Map({input}, add_and_exp_, {0}); + auto input = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10}), "in"); + Map(&builder, {input}, add_and_exp_, {0}); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -176,17 +176,17 @@ TEST_F(HloCostAnalysisTest, Map) { TEST_F(HloCostAnalysisTest, Convolution) { XlaBuilder builder("convolution"); - auto input = builder.Parameter( - 0, + auto input = Parameter( + &builder, 0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, /*x_dim=*/20}), "input"); - auto kernel = builder.Parameter( - 1, + auto kernel = Parameter( + &builder, 1, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3, /*x_dim=*/3}), "kernel"); - auto result = builder.Conv(input, kernel, {1, 1}, Padding::kValid); + Conv(input, kernel, {1, 1}, Padding::kValid); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -206,9 +206,8 @@ TEST_F(HloCostAnalysisTest, Convolution) { TEST_F(HloCostAnalysisTest, Reduce) { XlaBuilder builder("reduce"); auto input = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); - auto result = - builder.Reduce(input, builder.ConstantR0(0.0f), add_, {1}); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); + Reduce(input, ConstantR0(&builder, 0.0f), add_, {1}); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -224,9 +223,9 @@ TEST_F(HloCostAnalysisTest, Reduce) { TEST_F(HloCostAnalysisTest, ReduceWindow) { XlaBuilder builder("reduce_window"); auto input = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); - auto result = builder.ReduceWindow(input, builder.ConstantR0(0), add_, - {4, 5}, {4, 5}, Padding::kValid); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); + ReduceWindow(input, ConstantR0(&builder, 0), add_, {4, 5}, {4, 5}, + Padding::kValid); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -241,12 +240,11 @@ TEST_F(HloCostAnalysisTest, ReduceWindow) { TEST_F(HloCostAnalysisTest, SelectAndScatter) { XlaBuilder builder("select_and_scatter"); auto operand = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); auto source = - builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 4}), "source"); - auto result = - builder.SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid, - source, builder.ConstantR0(0), add_); + Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 4}), "source"); + SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid, source, + ConstantR0(&builder, 0), add_); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -261,7 +259,7 @@ TEST_F(HloCostAnalysisTest, SelectAndScatter) { TEST_F(HloCostAnalysisTest, Broadcast) { XlaBuilder b("broadcast"); - b.Broadcast(b.ConstantR0(42), {10, 7}); + Broadcast(ConstantR0(&b, 42), {10, 7}); auto hlo_module = BuildHloGraph(&b); HloCostAnalysis analysis(ShapeSize); ASSERT_IS_OK( @@ -273,13 +271,12 @@ TEST_F(HloCostAnalysisTest, Broadcast) { TEST_F(HloCostAnalysisTest, FullyConnectedForward) { XlaBuilder builder("fully_connected_forward"); auto input = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "input"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5}), "input"); auto weight = - builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 20}), "weight"); - auto bias = builder.Parameter(2, ShapeUtil::MakeShape(F32, {20}), "bias"); + Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 20}), "weight"); + auto bias = Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {20}), "bias"); // sigmoid(input * weight + bias) - auto result = builder.Map( - {builder.Add(builder.Dot(input, weight), bias, {1})}, sigmoid_, {0, 1}); + Map(&builder, {Add(Dot(input, weight), bias, {1})}, sigmoid_, {0, 1}); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -297,11 +294,11 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { HloCostAnalysis conv_analysis(ShapeSize); { XlaBuilder builder("conv_looking_matmul"); - auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), - "input"); - auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), - "weights"); - builder.Conv(lhs, rhs, {1, 1}, Padding::kSame); + auto lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), + "input"); + auto rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), + "weights"); + Conv(lhs, rhs, {1, 1}, Padding::kSame); auto hlo_module = BuildHloGraph(&builder); ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept( &conv_analysis)); @@ -311,10 +308,10 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { { XlaBuilder builder("matmul"); auto lhs = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64}), "input"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {64, 64}), "input"); auto rhs = - builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64}), "weights"); - builder.Dot(lhs, rhs); + Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {64, 64}), "weights"); + Dot(lhs, rhs); auto hlo_module = BuildHloGraph(&builder); ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept( &matmul_analysis)); @@ -419,9 +416,9 @@ TEST_F(HloCostAnalysisTest, TupleCost) { HloCostAnalysis analysis(ShapeSize); { XlaBuilder builder("matmul"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {123}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {42}), "y"); - auto tuple = builder.Tuple({x, y}); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {123}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {42}), "y"); + Tuple(&builder, {x, y}); auto hlo_module = BuildHloGraph(&builder); ASSERT_IS_OK( @@ -435,21 +432,21 @@ TEST_F(HloCostAnalysisTest, TupleCost) { TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { XlaBuilder builder("BaseDilatedConvolution"); - auto input = builder.Parameter( - 0, + auto input = Parameter( + &builder, 0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, /*x_dim=*/20}), "input"); - auto kernel = builder.Parameter( - 1, + auto kernel = Parameter( + &builder, 1, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3, /*x_dim=*/3}), "kernel"); - auto result = builder.ConvGeneralDilated( - input, kernel, /*window_strides=*/{1, 1}, /*padding=*/{{1, 1}, {1, 1}}, - /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11}, - XlaBuilder::CreateDefaultConvDimensionNumbers(2)); + ConvGeneralDilated(input, kernel, /*window_strides=*/{1, 1}, + /*padding=*/{{1, 1}, {1, 1}}, + /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11}, + XlaBuilder::CreateDefaultConvDimensionNumbers(2)); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -463,8 +460,8 @@ TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { TEST_F(HloCostAnalysisTest, Slice) { // Test the analysis on a slice. XlaBuilder builder("slice"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); - auto slice = builder.Slice(x, {0}, {1}, {1}); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x"); + Slice(x, {0}, {1}, {1}); auto hlo_module = BuildHloGraph(&builder); // Run HLO cost analysis. @@ -478,8 +475,8 @@ TEST_F(HloCostAnalysisTest, Slice) { TEST_F(HloCostAnalysisTest, DynamicSlice) { // Test the analysis on a slice. XlaBuilder builder("dynamic-slice"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); - auto slice = builder.DynamicSlice(x, builder.ConstantR1({1}), {1}); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x"); + DynamicSlice(x, ConstantR1(&builder, {1}), {1}); auto hlo_module = BuildHloGraph(&builder); // Run HLO cost analysis. @@ -493,9 +490,9 @@ TEST_F(HloCostAnalysisTest, DynamicSlice) { TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) { // Test the analysis on a slice. XlaBuilder builder("dynamic-update-slice"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); - auto slice = builder.DynamicUpdateSlice(x, builder.ConstantR1({1.0}), - builder.ConstantR1({1})); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x"); + DynamicUpdateSlice(x, ConstantR1(&builder, {1.0}), + ConstantR1(&builder, {1})); auto hlo_module = BuildHloGraph(&builder); // Run HLO cost analysis. diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index d0200058683b2db8f5f0469d6c643014881f741e..8a4a9b59868eb436842c9a819ffa8d6ec2054eee 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -34,16 +34,86 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" namespace xla { +namespace { + +// We have this pattern in dynamaic update slice fusion, which should be +// supported: +// +// Parameters: p0, p1 +// Fusion +// ds = DynamicSlice(p0, p1) +// ROOT DynamicUpdateslice(p0, ds, p1) +// +// In this case, we should be able to reuse p0 and output, although p0 has +// multiple uses. +bool MultiDynamicSliceUseShareSameIndices( + tensorflow::gtl::ArraySlice uses) { + if (uses.empty()) { + return false; + } + const HloInstruction* indices = nullptr; + for (HloUse use : uses) { + auto user = use.instruction; + if (user->opcode() == HloOpcode::kDynamicUpdateSlice) { + if (indices == nullptr) { + indices = user->operand(2); + } else if (indices != user->operand(2)) { + return false; + } + if (use.operand_number != 0) { + return false; + } + } else if (user->opcode() == HloOpcode::kDynamicSlice) { + if (indices == nullptr) { + indices = user->operand(1); + } else if (indices != user->operand(1)) { + return false; + } + } else { + return false; + } + } + return true; +} + +} // namespace using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; -HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form, - bool bitcast_defines_value) +HloDataflowAnalysis::HloDataflowAnalysis( + const HloModule& module, bool ssa_form, bool bitcast_defines_value, + const FusionCanShareBufferFunction& fusion_can_share_buffer) : module_(module), ssa_form_(ssa_form), bitcast_defines_value_(bitcast_defines_value), - call_graph_(CallGraph::Build(&module)) {} + call_graph_(CallGraph::Build(&module)), + fusion_can_share_buffer_(fusion_can_share_buffer) {} + +bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( + const HloInstruction* inst) { + tensorflow::gtl::FlatSet visited; + tensorflow::gtl::InlinedVector stack; + stack.push_back(inst); + while (!stack.empty()) { + const HloInstruction* current = stack.back(); + stack.pop_back(); + visited.insert(current); + for (const HloInstruction* user : current->users()) { + // Found a user that is non-elementwise on current instruction. + for (const int64 use_index : user->OperandIndices(current)) { + if (!user->IsElementwiseOnOperand(use_index) && + user->opcode() != HloOpcode::kTuple) { + return false; + } + } + if (!visited.count(user)) { + stack.push_back(user); + } + } + } + return true; +} bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction, const ShapeIndex& index) const { @@ -396,6 +466,24 @@ bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) { return changed; } +bool HloDataflowAnalysis::UpdateDomainValueSet(HloInstruction* domain) { + // Domain instructions just forward their operand. Given that domains can have + // a tuple operand, we iterate through its indexes, like for copies. + // Unlike copies though we also propagate the top-level value. + CHECK_EQ(domain->opcode(), HloOpcode::kDomain); + bool changed = false; + for (auto& pair : GetInstructionValueSet(domain)) { + const ShapeIndex& index = pair.first; + HloValueSet& value_set = pair.second; + HloValueSet& operand_value_set = GetValueSet(domain->operand(0), index); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + } + return changed; +} + bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) { CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement); bool changed = false; @@ -556,6 +644,8 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( return UpdateBitcastValueSet(instruction); case HloOpcode::kSlice: return UpdateSliceValueSet(instruction); + case HloOpcode::kDomain: + return UpdateDomainValueSet(instruction); case HloOpcode::kCopy: return UpdateCopyValueSet(instruction); case HloOpcode::kGetTupleElement: @@ -734,6 +824,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kGetTupleElement: + case HloOpcode::kDomain: // These instructions define no values. The values in their output // flow from their operands or from cross computation dataflow. break; @@ -787,12 +878,13 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { /* static */ StatusOr> HloDataflowAnalysis::Run( - const HloModule& module, bool ssa_form, bool bitcast_defines_value) { + const HloModule& module, bool ssa_form, bool bitcast_defines_value, + const FusionCanShareBufferFunction& fusion_can_share_buffer) { VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name(); XLA_VLOG_LINES(2, module.ToString()); - auto dataflow_analysis = WrapUnique( - new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value)); + auto dataflow_analysis = WrapUnique(new HloDataflowAnalysis( + module, ssa_form, bitcast_defines_value, fusion_can_share_buffer)); TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets()); dataflow_analysis->Propagate(); @@ -915,6 +1007,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( ShapeUtil::GetSubshape(operand->shape(), operand_index); const Shape& user_subshape = ShapeUtil::GetSubshape(user->shape(), user_index); + // Check that operand and user emit the same shape and layout. if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { return false; @@ -927,11 +1020,15 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( const HloValue& value = GetValueDefinedAt(fusion_param, operand_index); if (value.uses().size() != 1) { + if (MultiDynamicSliceUseShareSameIndices(value.uses())) { + return true; + } return false; } const HloUse& use = value.uses()[0]; - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop || + user->fusion_kind() == HloInstruction::FusionKind::kInput) { if (user->fused_expression_root()->opcode() == HloOpcode::kDynamicUpdateSlice) { // Loop fusion with kDynamicUpdateSlice fused root. @@ -941,6 +1038,8 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // index 0. return use.instruction == user->fused_expression_root() && use.operand_number == 0; + } else { + return AreTransitiveUsesElementwiseOrTuple(fusion_param); } } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && user->fused_expression_root()->opcode() == HloOpcode::kAdd) { @@ -966,6 +1065,9 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // index 'other_add_operand_index'). return use.instruction == user->fused_expression_root() && use.operand_number == other_add_operand_index; + } else if (fusion_can_share_buffer_ != nullptr && + fusion_can_share_buffer_(user, operand)) { + return true; } } @@ -1003,9 +1105,6 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // Loop fusions that contain transposing copies won't reach here as they have // different layouts, which fails the check in the beginning of this function. - // - // Multi-output fusion will fail the check here as tuples are not considered - // an elementwise operation. return user->IsElementwiseOnOperand(user->operand_index(operand)); } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 9868746b6113881949e388cd2a4aa9f610b1fdb7..9fea218af0c4ac8a512bea5c187564a8219d041f 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -42,6 +42,20 @@ namespace xla { // Analysis which identifies all HLO values and their uses in an HLO module. class HloDataflowAnalysis { public: + // Different backends can have very different ways to do fusion, so we give + // backends the flexibility to decide whether an fusion instruction can share + // buffer with it's operands. If this is not specified, a default strategy + // will be used; if this is specified, it will be applied *in addition* to the + // default strategy. + // + // The first parameter of the function should be the fusion instruction, the + // second parameter should be an operand of the fusion instruction. + // + // TODO(b/80315712): Find a better way to tell whether a fusion can share + // buffer. + using FusionCanShareBufferFunction = std::function; + // Run dataflow analysis on the given module. Parameters: // // ssa_form : If true then new values are defined at the merge points of @@ -61,7 +75,10 @@ class HloDataflowAnalysis { // value of its operand. static StatusOr> Run( const HloModule& module, bool ssa_form = false, - bool bitcast_defines_value = false); + bool bitcast_defines_value = false, + const FusionCanShareBufferFunction& fusion_can_share_buffer = nullptr); + + static bool AreTransitiveUsesElementwiseOrTuple(const HloInstruction* inst); // Returns true if 'instruction' defines an HLO value at the given shape index // of its output. @@ -136,8 +153,10 @@ class HloDataflowAnalysis { const ShapeIndex& user_index) const; protected: - HloDataflowAnalysis(const HloModule& module, bool ssa_form, - bool bitcast_defines_value = false); + HloDataflowAnalysis( + const HloModule& module, bool ssa_form, + bool bitcast_defines_value = false, + const FusionCanShareBufferFunction& fusion_can_share_buffer = nullptr); // Returns a new HloValue defined at the given instruction and shape index. HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index, @@ -166,6 +185,7 @@ class HloDataflowAnalysis { bool UpdateCallValueSet(HloInstruction* call); bool UpdateConditionalValueSet(HloInstruction* conditional); bool UpdateCopyValueSet(HloInstruction* copy); + bool UpdateDomainValueSet(HloInstruction* domain); bool UpdateGetTupleElementValueSet(HloInstruction* gte); bool UpdateParameterValueSet(HloInstruction* parameter); bool UpdateRecvDoneValueSet(HloInstruction* recv_done); @@ -221,6 +241,10 @@ class HloDataflowAnalysis { // The Id to use for the next HloValue. HloValue::Id next_value_id_ = 0; + + // Backend specific function that decides whether a fusion can share buffer + // with its operand. + FusionCanShareBufferFunction fusion_can_share_buffer_ = nullptr; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index db1822ec47a7f52e2c3ef8dcbf433cd787ef75ab..0ea8bdcab680a40fd9301f2dcd5e0e176ac73d15 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1880,9 +1880,14 @@ class HloDataflowAnalysisTestBase : public HloTestBase { computation_ = module_->AddEntryComputation(std::move(computation)); } - void RunAnalysis() { + void RunAnalysis(const HloDataflowAnalysis::FusionCanShareBufferFunction& + fusion_can_share_buffer = nullptr) { CHECK_NOTNULL(module_.get()); - dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie(); + dataflow_analysis_ = + HloDataflowAnalysis::Run(*module_, /*ssa_form=*/false, + /*bitcast_defines_value=*/false, + fusion_can_share_buffer) + .ConsumeValueOrDie(); } void BuildModuleAndRunAnalysis(std::unique_ptr computation) { @@ -1998,7 +2003,7 @@ TEST_F(CanShareOperandBufferWithUserTest, } TEST_F(CanShareOperandBufferWithUserTest, - MultiOutputFusionCantAliasOperandBuffer) { + MultiOutputFusionCanAliasOperandBuffer) { auto builder = HloComputation::Builder(TestName()); Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); @@ -2022,14 +2027,14 @@ TEST_F(CanShareOperandBufferWithUserTest, {tuple, copy1, copy0}, HloInstruction::FusionKind::kLoop); RunAnalysis(); - EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, - fusion, {0})); - EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, - fusion, {1})); - EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, - fusion, {0})); - EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, - fusion, {1})); + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {0})); + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {1})); + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {0})); + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {1})); } TEST_F(CanShareOperandBufferWithUserTest, @@ -2057,6 +2062,31 @@ TEST_F(CanShareOperandBufferWithUserTest, fusion, {})); } +TEST_F(CanShareOperandBufferWithUserTest, + CanShareOperandWhenDynamicUpdateSliceIsFedByDynamicSliceWithSameIndex) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + Shape slice_shape = ShapeUtil::MakeShape(F32, {1, 2}); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "param0")); + auto index = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({0, 0}))); + auto ds = builder.AddInstruction( + HloInstruction::CreateDynamicSlice(slice_shape, param, index, {1, 2, 2})); + + auto dus = builder.AddInstruction( + HloInstruction::CreateDynamicUpdateSlice(data_shape, param, ds, index)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dus, ds, index}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, fusion, {})); +} + TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { auto builder = HloComputation::Builder(TestName()); @@ -2132,7 +2162,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { } TEST_F(CanShareOperandBufferWithUserTest, - FusedDynamicUpdateSliceWithConvertCantShare) { + FusedDynamicUpdateSliceWithConvertCanShare) { auto builder = HloComputation::Builder(TestName()); Shape data_shape = ShapeUtil::MakeShape(F32, {8}); @@ -2166,8 +2196,7 @@ TEST_F(CanShareOperandBufferWithUserTest, HloInstruction::FusionKind::kLoop); RunAnalysis(); - // The fusion instruction can't share with tuple element 1. - EXPECT_FALSE( + EXPECT_TRUE( dataflow_analysis_->CanShareOperandBufferWithUser(gte1, {}, fusion, {})); } @@ -2259,6 +2288,33 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { fusion, {})); } +TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape, HloOpcode::kMultiply, operand, operand)); + auto two = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, mul, two)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, two, mul}, HloInstruction::FusionKind::kInput); + RunAnalysis(/*fusion_can_share_buffer=*/[](const HloInstruction* fusion, + const HloInstruction*) { + return fusion->fusion_kind() == HloInstruction::FusionKind::kLoop; + }); + + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {}, + fusion, {})); +} + TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { Shape data_shape = ShapeUtil::MakeShape(F32, {8}); diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index fcd723af146e2227b8661b1a4993f1338f7de389..7d35e251ca21951036336ff1a1eb4aabc87bc5ca 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -41,20 +41,13 @@ StatusOr HloDCE::Run(HloModule* module) { XLA_VLOG_LINES(2, module->ToString()); for (auto* computation : module->MakeComputationPostOrder()) { - std::unordered_set live_instructions; - TF_RETURN_IF_ERROR(computation->root_instruction()->Accept( - [&live_instructions](HloInstruction* instruction) { - live_instructions.insert(instruction); - return Status::OK(); - })); - // Remove any dead roots and their dead transitive operands. Collect them // into a separate list first to avoid problems with iterating through the // computation's instruction while simultaneously removing instructions. std::vector dead_roots; for (auto* instruction : computation->instructions()) { - if (instruction->user_count() == 0 && - live_instructions.count(instruction) == 0 && + if (instruction != computation->root_instruction() && + instruction->user_count() == 0 && computation->IsRemovable(instruction) && !instruction->HasSideEffect()) { dead_roots.push_back(instruction); @@ -85,8 +78,7 @@ StatusOr HloDCE::Run(HloModule* module) { } // Remove dead computations. - std::list computations = module->MakeComputationPostOrder(); - for (auto* computation : computations) { + for (auto* computation : module->MakeComputationPostOrder()) { if (live_computations.count(computation) == 0) { TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(computation)); changed = true; diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 5a56607a665c4cbeb7b2572f182b88e890602968..2822ecd788f624ff4e289f4b2d32fb83caf8bd77 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -234,9 +234,10 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) { { auto param = body_builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param")); - - auto infeed = - body_builder.AddInstruction(HloInstruction::CreateInfeed(shape, "")); + auto token = + body_builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto infeed = body_builder.AddInstruction( + HloInstruction::CreateInfeed(shape, token, "")); body_builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, infeed)); } @@ -278,8 +279,10 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) { { auto param = nested_callee_builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param")); + auto token = nested_callee_builder.AddInstruction( + HloInstruction::CreateAfterAll({})); nested_callee_builder.AddInstruction( - HloInstruction::CreateOutfeed(shape, param, "")); + HloInstruction::CreateOutfeed(shape, param, token, "")); } auto nested_called_computation = module->AddEmbeddedComputation(nested_callee_builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h index e0c5718509dabebb7b9307bf764b0ea1ce7369a0..eded3e78eead76c4564daee119034c5031eba409 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h @@ -26,10 +26,10 @@ limitations under the License. namespace xla { // Domain isolation is the task of placing kDomain instructions between HLO -// instructions having different shrading. A kDomain instruction is essentially +// instructions having different sharding. A kDomain instruction is essentially // used to break an HLO graph edge connecting two instructions with different // sharding. If a set of connected instructions have all the same sharding, no -// kDomain instruciton will be placed. +// kDomain instruction will be placed. class HloDomainIsolator : public HloPassInterface { public: // Creates a new kDomain instruction for the edge between the use instruction diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index 5553ddb153f7f1f2e6a790890c11f35e192488c4..abc5b1c8effe03e39a2683eb2876ad0a27293921 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -21,12 +21,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class HloDomainTest : public HloTestBase { +class HloDomainTest : public HloVerifiedTestBase { protected: bool FindUserViaDomainPath(HloInstruction* instruction, HloInstruction* operand) const { @@ -64,11 +65,11 @@ class HloDomainTest : public HloTestBase { return false; } - StatusOr> ParseModule( - tensorflow::StringPiece hlo_string) { + StatusOr ParseModule(tensorflow::StringPiece hlo_string) { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - return ParseHloString(hlo_string, config); + ParseAndVerifyModule(hlo_string, config); + return &module(); } }; @@ -143,32 +144,31 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator(CreateShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); - EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); - EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); - EXPECT_TRUE(HasDomainEdge(module.get(), "d", "b")); - EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); - EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + EXPECT_TRUE(HasDomainEdge(module, "c", "a")); + EXPECT_TRUE(HasDomainEdge(module, "c", "b")); + EXPECT_TRUE(HasDomainEdge(module, "d", "a")); + EXPECT_TRUE(HasDomainEdge(module, "d", "b")); + EXPECT_FALSE(HasDomainEdge(module, "e", "c")); + EXPECT_FALSE(HasDomainEdge(module, "e", "d")); HloDomainRemover remover(ShardingMetadata::KindName(), NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); EXPECT_TRUE(remover_changed); - EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); - EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); - EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); - EXPECT_FALSE(HasDomainEdge(module.get(), "d", "b")); - EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); - EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + EXPECT_FALSE(HasDomainEdge(module, "c", "a")); + EXPECT_FALSE(HasDomainEdge(module, "c", "b")); + EXPECT_FALSE(HasDomainEdge(module, "d", "a")); + EXPECT_FALSE(HasDomainEdge(module, "d", "b")); + EXPECT_FALSE(HasDomainEdge(module, "e", "c")); + EXPECT_FALSE(HasDomainEdge(module, "e", "d")); } TEST_F(HloDomainTest, CheckNoDomainAddedIfNoSharding) { @@ -186,12 +186,11 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator(CreateShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(!isolator_changed); } @@ -212,27 +211,26 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator(CreateShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module.get(), "b", "a")); - EXPECT_TRUE(HasDomainEdge(module.get(), "f", "e")); - EXPECT_FALSE(HasDomainEdge(module.get(), "a", "p0")); - EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); - EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + EXPECT_TRUE(HasDomainEdge(module, "b", "a")); + EXPECT_TRUE(HasDomainEdge(module, "f", "e")); + EXPECT_FALSE(HasDomainEdge(module, "a", "p0")); + EXPECT_FALSE(HasDomainEdge(module, "c", "b")); + EXPECT_FALSE(HasDomainEdge(module, "e", "d")); HloDomainRemover remover(ShardingMetadata::KindName(), NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); EXPECT_TRUE(remover_changed); - EXPECT_FALSE(HasDomainEdge(module.get(), "b", "a")); - EXPECT_FALSE(HasDomainEdge(module.get(), "f", "e")); + EXPECT_FALSE(HasDomainEdge(module, "b", "a")); + EXPECT_FALSE(HasDomainEdge(module, "f", "e")); } TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { @@ -248,12 +246,11 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator(CreateShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_FALSE(isolator_changed); } @@ -270,16 +267,15 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainRemover remover(ShardingMetadata::KindName(), NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); EXPECT_FALSE(remover_changed); - HloInstruction* add = FindInstruction(module.get(), "c"); + HloInstruction* add = FindInstruction(module, "c"); ASSERT_NE(add, nullptr); auto device = add->sharding_unique_device(); EXPECT_TRUE(device.has_value()); @@ -302,42 +298,41 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator sharding_isolator(CreateShardingDomain); TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed, - sharding_isolator.Run(module.get())); + sharding_isolator.Run(module)); EXPECT_TRUE(sharding_isolator_changed); HloDomainIsolator opname_isolator(OpNameDomainCreator); TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed, - opname_isolator.Run(module.get())); + opname_isolator.Run(module)); EXPECT_TRUE(opname_isolator_changed); - EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); - EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); - EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); - EXPECT_TRUE(HasDomainEdge(module.get(), "d", "c")); - EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + EXPECT_TRUE(HasDomainEdge(module, "c", "a")); + EXPECT_TRUE(HasDomainEdge(module, "c", "b")); + EXPECT_TRUE(HasDomainEdge(module, "d", "a")); + EXPECT_TRUE(HasDomainEdge(module, "d", "c")); + EXPECT_FALSE(HasDomainEdge(module, "e", "d")); HloDomainRemover sharding_remover(ShardingMetadata::KindName(), NormalizeShardingDomain); TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, - sharding_remover.Run(module.get())); + sharding_remover.Run(module)); EXPECT_TRUE(sharding_remover_changed); HloDomainRemover opname_remover(OpNameMetadata::KindName(), OpNameDomainNormalizer); TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed, - opname_remover.Run(module.get())); + opname_remover.Run(module)); EXPECT_TRUE(opname_remover_changed); - EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); - EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); - EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); - EXPECT_FALSE(HasDomainEdge(module.get(), "d", "c")); + EXPECT_FALSE(HasDomainEdge(module, "c", "a")); + EXPECT_FALSE(HasDomainEdge(module, "c", "b")); + EXPECT_FALSE(HasDomainEdge(module, "d", "a")); + EXPECT_FALSE(HasDomainEdge(module, "d", "c")); } TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) { @@ -345,33 +340,35 @@ TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) { HloModule Module ENTRY entry { - infeed = (f32[4], f32[4]) infeed(), - sharding={{maximal device=1}, {maximal device=0}} - gte0 = f32[4] get-tuple-element(infeed), index=0 - gte1 = f32[4] get-tuple-element(infeed), index=1 + token = token[] after-all() + infeed = ((f32[4], f32[4]), token[]) infeed(token), + sharding={{maximal device=1}, {maximal device=0}, {maximal device=0}} + infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0 + gte0 = f32[4] get-tuple-element(infeed.data), index=0 + gte1 = f32[4] get-tuple-element(infeed.data), index=1 copy0 = f32[4] copy(gte0) copy1 = f32[4] copy(gte1) ROOT add = f32[4] add(copy0, copy1) } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator(CreateShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module.get(), "gte0", "infeed")); - EXPECT_TRUE(HasDomainEdge(module.get(), "gte1", "infeed")); - EXPECT_FALSE(HasDomainEdge(module.get(), "copy0", "gte0")); - EXPECT_FALSE(HasDomainEdge(module.get(), "copy1", "gte1")); + EXPECT_TRUE(HasDomainEdge(module, "infeed.data", "infeed")); + EXPECT_FALSE(HasDomainEdge(module, "copy0", "gte0")); + EXPECT_FALSE(HasDomainEdge(module, "copy1", "gte1")); // Inject unassigned tuple/gte within the infeed domain, to simulate the // HLO passes adding unexpected instructions. // // infeed + // | + // infeed.data (tuple element 0 of infeed) // / \ // GTE0 GTE1 // / \ @@ -380,31 +377,36 @@ ENTRY entry { // \ / // TUPLE // | - // DOMAIN - HloInstruction* infeed = FindInstruction(module.get(), "infeed"); + HloInstruction* infeed = FindInstruction(module, "infeed"); ASSERT_NE(infeed, nullptr); - auto infeed_users = infeed->users(); - HloInstruction* new_gte0 = + HloInstruction* infeed_data = infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetTupleElementShape(infeed->shape(), 0), infeed, 0)); + + auto infeed_data_users = infeed_data->users(); + HloInstruction* new_gte0 = infeed_data->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(infeed_data->shape(), 0), infeed_data, + 0)); HloInstruction* new_copy0 = - infeed->parent()->AddInstruction(HloInstruction::CreateUnary( + infeed_data->parent()->AddInstruction(HloInstruction::CreateUnary( new_gte0->shape(), HloOpcode::kCopy, new_gte0)); - HloInstruction* new_gte1 = - infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::GetTupleElementShape(infeed->shape(), 1), infeed, 1)); + HloInstruction* new_gte1 = infeed_data->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(infeed_data->shape(), 1), infeed_data, + 1)); HloInstruction* new_copy1 = - infeed->parent()->AddInstruction(HloInstruction::CreateUnary( + infeed_data->parent()->AddInstruction(HloInstruction::CreateUnary( new_gte1->shape(), HloOpcode::kCopy, new_gte1)); - HloInstruction* new_tuple = infeed->parent()->AddInstruction( + HloInstruction* new_tuple = infeed_data->parent()->AddInstruction( HloInstruction::CreateTuple({new_copy0, new_copy1})); - for (HloInstruction* user : infeed_users) { - TF_EXPECT_OK(infeed->ReplaceUseWith(user, new_tuple)); + for (HloInstruction* user : infeed_data_users) { + TF_EXPECT_OK(infeed_data->ReplaceUseWith(user, new_tuple)); } HloDomainRemover remover(ShardingMetadata::KindName(), NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); EXPECT_TRUE(remover_changed); struct Assignment { @@ -418,7 +420,7 @@ ENTRY entry { }; for (auto& assignment : assignments) { auto device = assignment.instruction->sharding_unique_device(); - EXPECT_TRUE(device.has_value()); + ASSERT_TRUE(device.has_value()); EXPECT_EQ(*device, assignment.device); } EXPECT_TRUE(new_tuple->has_sharding()); @@ -428,5 +430,26 @@ ENTRY entry { HloSharding::AssignDevice(0)})); } +// Tests that text dumps of domain instructions can be parsed back, in the +// specific case of null shardings. +TEST_F(HloDomainTest, DumpParseNullSharding) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {}); + auto sharding_md_0 = MakeUnique(nullptr); + auto sharding_md_1 = MakeUnique(nullptr); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p")); + HloInstruction* domain = builder.AddInstruction(HloInstruction::CreateDomain( + shape, param, std::move(sharding_md_0), std::move(sharding_md_1))); + builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, domain, domain)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + auto hlo_string = module->ToString(); + ASSERT_TRUE(ParseModule(hlo_string).status().ok()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc index 5c5a059e0fd895f03bc26a975609b57333237faf..c170e36c73ad2bef830e528de3ec72d38683d888 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc @@ -57,8 +57,10 @@ TEST_F(HloElementTypeConverterTest, InfeedsOutfeedsNotConverted) { const string& hlo_string = R"( HloModule InfeedOutfeed ENTRY RoundTrip16MiBR1.v2 { - ROOT infeed = bf16[4]{0} infeed() - outfeed = () outfeed(infeed) + token = token[] after-all() + infeed = (bf16[4]{0}, token[]) infeed(token) + ROOT infeed.data = bf16[4]{0} get-tuple-element(infeed), index=0 + outfeed = token[] outfeed(infeed.data, token) } )"; auto module = CreateModuleFromHloString(hlo_string); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 33424019b93feff862c6e3e268ae3980bacc9142..e65e1af20c156f6b8fc16566ce548be6ce0d746b 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -902,7 +902,7 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { return Status::OK(); } -Status HloEvaluator::HandleGenerateToken(HloInstruction* token) { +Status HloEvaluator::HandleAfterAll(HloInstruction* token) { evaluated_[token] = Literal::CreateToken(); return Status::OK(); } @@ -1068,6 +1068,19 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { return Status::OK(); } +Status HloEvaluator::HandleSort(HloInstruction* sort) { + if (!ShapeUtil::IsTuple(sort->shape())) { + return DefaultAction(sort); + } + // The key-value version of Sort is a special snowflake, since the output + // shape is a tuple, so its element type is not meaningful. + // + // TODO(mkuper): Do something sane here, so that we can support different key + // and value types. + return sort->Visit( + typed_visitors_.at(sort->operand(0)->shape().element_type()).get()); +} + Status HloEvaluator::Preprocess(HloInstruction* hlo) { VLOG(2) << "About to visit HLO: " << hlo->ToString(); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index fc2fc9437b238a2e519401b2b121dfbef070e2dc..b330c30eeb668dfbbb6e42a401b6e93045ee50f5 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -174,7 +174,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleBroadcast(HloInstruction* broadcast) override; - Status HandleGenerateToken(HloInstruction* token) override; + Status HandleAfterAll(HloInstruction* token) override; + + Status HandleSort(HloInstruction* sort) override; // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 72eb9930e92c340ab9f42cd563c27507623b2ba7..42770d848a83b2e27b87bc963d259e2b7af664a4 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -206,6 +206,15 @@ TEST_P(HloEvaluatorTest, DoesOr) { std::move(rhs)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs +// element-wise or with 2 operands. +TEST_P(HloEvaluatorTest, DoesXor) { + auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); + auto expected = Literal::CreateR2({{3, 4}, {-104, 0}}); + TestBinaryOp(HloOpcode::kXor, std::move(expected), std::move(lhs), + std::move(rhs)); +} +// Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise multiply with 2 operands. TEST_P(HloEvaluatorTest, DoesMultiply) { auto lhs = Literal::CreateR2({{-1, 0}, {-100, 4}}); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index bc7340aa036ecb322b37fbe4c72fa43485b2f57d..1136178e90b216960543c194348dfbceb964ca95 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -610,12 +610,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleAnd(HloInstruction* and_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[and_], - ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el && rhs_el; - })); - return Status::OK(); + return InvalidArgument("Unsupported type for And"); } template < @@ -644,12 +639,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleOr(HloInstruction* or_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[or_], - ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el || rhs_el; - })); - return Status::OK(); + return InvalidArgument("Unsupported type for Or"); } template < @@ -663,6 +653,35 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleOr(or_); } + template ::value>::type* = + nullptr> + Status HandleXor(HloInstruction* xor_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[xor_], + ElementWiseBinaryOp(xor_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el ^ rhs_el; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleXor(HloInstruction* xor_) { + return InvalidArgument("Unsupported type for Xor"); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleXor(HloInstruction* xor_) { + return InvalidArgument("Unsupported type for Xor"); + } + + Status HandleXor(HloInstruction* xor_) override { + return HandleXor(xor_); + } + template ::value && @@ -1006,83 +1025,47 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { CHECK_EQ(dnums.lhs_batch_dimensions_size(), dnums.rhs_batch_dimensions_size()); - std::vector lhs_non_contracting_dims; + DimensionVector lhs_index(lhs_rank); + DimensionVector rhs_index(rhs_rank); + + // result_index_locations[i] contains one or two pointers to the locations + // in lhs_index or rhs_index where the i'th result index should go. + tensorflow::gtl::InlinedVector, kInlineRank> + result_index_locations; + result_index_locations.reserve(lhs_rank + rhs_rank - 2); + + // The first components in the output shape are the LHS and RHS batch + // dimensions: + for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); i++) { + result_index_locations.push_back( + {&lhs_index[dnums.lhs_batch_dimensions(i)], + &rhs_index[dnums.rhs_batch_dimensions(i)]}); + } + + // Then we have the LHS and RHS non-contracting dimensions, if any: for (int64 i = 0; i < lhs_rank; i++) { - if (i != lhs_contracting_dimension) { - lhs_non_contracting_dims.push_back(i); + if (i != lhs_contracting_dimension && + !ArrayContains(AsInt64Slice(dnums.lhs_batch_dimensions()), i)) { + result_index_locations.push_back({&lhs_index[i], nullptr}); } } - - std::vector rhs_non_batch_non_contracting_dims; - tensorflow::gtl::FlatSet batch_dims_set( - dnums.rhs_batch_dimensions().begin(), - dnums.rhs_batch_dimensions().end()); for (int64 i = 0; i < rhs_rank; i++) { - if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) { - rhs_non_batch_non_contracting_dims.push_back(i); + if (i != rhs_contracting_dimension && + !ArrayContains(AsInt64Slice(dnums.rhs_batch_dimensions()), i)) { + result_index_locations.push_back({&rhs_index[i], nullptr}); } } - const int64 batch_dim_size = dnums.lhs_batch_dimensions_size(); - const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size(); - - DimensionVector lhs_index(lhs_rank); - DimensionVector rhs_index(rhs_rank); auto result = MakeUnique(dot->shape()); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice result_index) { ElementwiseT result_val = static_cast(0); - // Find the corresponding non-contracting indices for lhs and rhs. - // - // For `result_index`, its batch dimension, if exists, will be at the - // same dimension as the batch dimension of lhs and rhs. More - // specifically: - // - For lhs, the non-contracting dimensions, including the batch - // dimension have the same index as the `result_index`. - // - For rhs, the batch dimension is set seperately from other - // non-contracting dimensions, since these other non-contracting - // dimensions in rhs follow the non-contracting dimensions of lhs in - // the resulting index. - // - // As an example, for a resulting index: - // result_index [result_batch, result_x, result_y] - // the effecting lhs and rhs indices are: - // lhs [result_batch, lhs_non_contracting_dim, contracting_dim - // rhs [result_batch, contracting_dim, rhs_non_contracting_dim] - // `result_x` is only affected by the lhs_non_contracting_dim and - // likewise `result_y` only depends on rhs_non_contracting_dim. - // - // so we can look up the lhs and rhs indices by: - // - // lhs: - // batch index is the same as `result_batch`. - // non-contracting dimension is the same as - // result_index[lhs_non_contracting_dim] - // rhs: - // batch index: the same as `result_batch`. - // non-contracting dimension index: *not* the same as - // result_index[rhs_non_contractng_dim], since the - // non-contracting dimensions of lhs are included in the - // result_index first. Instead, the non_contracting_dim of rhs must - // be calculated as following: - // lhs_non_contracting_dimensions_size + - // (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1 - // - // Note that (rhs_non_batch_contracting_dim - batch_dim_size) is - // the index offset to the result_index that only depends on - // the non_batch and non-contracting dimensions of rhs. -1 at the - // end translates size to index. - for (auto i : lhs_non_contracting_dims) { - lhs_index[i] = result_index[i]; - } - for (auto i : dnums.rhs_batch_dimensions()) { - rhs_index[i] = result_index[i]; - } - for (auto i : rhs_non_batch_non_contracting_dims) { - const int64 rhs_non_batch_non_contracting_dim = - lhs_non_contracting_size + (i - batch_dim_size) - 1; - rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim]; + for (int64 i = 0; i < result_index.size(); i++) { + *result_index_locations[i].first = result_index[i]; + if (result_index_locations[i].second) { + *result_index_locations[i].second = result_index[i]; + } } // Accumulates resulting product along the contracted dimension. @@ -1378,6 +1361,88 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleSort(HloInstruction* sort) { + auto keys = sort->operand(0); + TF_RET_CHECK(ShapeUtil::Rank(keys->shape()) == 1) + << "Sort is only supported for R1 shapes"; + + const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys); + VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); + const auto& keys_data = keys_literal.data(); + + if (sort->operand_count() == 1) { + std::vector result_data(keys_data.begin(), keys_data.end()); + std::sort(result_data.begin(), result_data.end(), + [](const ReturnT& a, const ReturnT& b) { + return SafeLess(a, b); + }); + auto result_literal = MakeUnique(sort->shape()); + result_literal->PopulateR1( + tensorflow::gtl::ArraySlice(result_data)); + VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); + parent_->evaluated_[sort] = std::move(result_literal); + } else { + CHECK_EQ(sort->operand_count(), 2); + auto values = sort->operand(1); + if (values->shape().element_type() != + primitive_util::NativeToPrimitiveType()) { + return InvalidArgument( + "Evaluator requires value and key types for Sort to match"); + } + + // We need to sort and array of keys and an array of values, where the + // sorted order of the values is determined by the keys. The simplest(?) + // way to do this is to go to an array-of-pairs representation, sort the + // array using the keys, and then go back to pair-of-arrays. + const Literal& values_literal = parent_->GetEvaluatedLiteralFor(values); + VLOG(3) << "HandleSort values_literal: " << values_literal.ToString(); + const auto& values_data = values_literal.data(); + using kv_pair = std::pair; + std::vector key_value_vector; + CHECK_EQ(keys_data.size(), values_data.size()); + for (int i = 0; i < keys_data.size(); ++i) { + key_value_vector.push_back( + std::make_pair(keys_data[i], values_data[i])); + } + std::sort(key_value_vector.begin(), key_value_vector.end(), + [](const kv_pair& a, const kv_pair& b) { + return SafeLess(a.first, b.first); + }); + std::vector result_keys, result_values; + for (const auto& key_value : key_value_vector) { + result_keys.push_back(key_value.first); + result_values.push_back(key_value.second); + } + auto result_keys_literal = MakeUnique(keys->shape()); + result_keys_literal->PopulateR1( + tensorflow::gtl::ArraySlice(result_keys)); + auto result_values_literal = MakeUnique(values->shape()); + result_values_literal->PopulateR1( + tensorflow::gtl::ArraySlice(result_values)); + auto result_tuple = Literal::MakeTuple( + {result_keys_literal.get(), result_values_literal.get()}); + VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString(); + parent_->evaluated_[sort] = std::move(result_tuple); + } + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleSort(HloInstruction* sort) { + return InvalidArgument("Unsupported type for Sort"); + } + + Status HandleSort(HloInstruction* sort) override { + return HandleSort(sort); + } + Status HandleReduce(HloInstruction* reduce) override { auto arg = reduce->operand(0); auto init_value = reduce->operand(1); @@ -2118,6 +2183,38 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return rhs_unsigned >= lhs_size_unsigned; } + // It's UB to use std::sort with std::less, because of NaNs. Define + // "safe" less functions which are actually strict weak orders. + template ::value>::type* = + nullptr> + static bool SafeLess(const NativeT& a, const NativeT& b) { + return a < b; + } + + template ::value || + std::is_same::value>::type* = nullptr> + static bool SafeLess(const NativeT& a, const NativeT& b) { + if (std::isnan(b)) { + return !std::isnan(a); + } else { + return a < b; + } + } + + template ::value>::type* = nullptr> + static bool SafeLess(const NativeT& a, const NativeT& b) { + if (Eigen::half_impl::isnan(b)) { + return !Eigen::half_impl::isnan(a); + } else { + return a < b; + } + } + HloEvaluator* parent_; }; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index ab224021c54fb3f5c5b69d0b633a080c304d5edd..8856723f67cf22c44e5ee482777a6a0908d1725d 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -960,6 +960,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kNegate: case HloOpcode::kNot: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kPower: case HloOpcode::kReal: case HloOpcode::kRemainder: @@ -983,7 +984,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kBitcast: case HloOpcode::kGetTupleElement: case HloOpcode::kTrace: - case HloOpcode::kGenerateToken: + case HloOpcode::kAfterAll: case HloOpcode::kTuple: return kWhite; case HloOpcode::kBroadcast: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index a1af8939e7411bbda9ed74e6bfb5c24fc6a0f940..e0e3d301be957df0f0cdf4e01e00cc46d2760d89 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" @@ -65,31 +64,47 @@ StatusOr> HloInstruction::CreateFromProto( const auto operands = [&instruction_map, &proto](int index) { return instruction_map.at(proto.operand_ids(index)); }; + const auto all_operands = [&instruction_map, &proto]() { + std::vector result(proto.operand_ids_size()); + std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), + result.begin(), [&instruction_map](int64 operand_id) { + return instruction_map.at(operand_id); + }); + return result; + }; const auto computations = [&computation_map, &proto](int index) { return computation_map.at(proto.called_computation_ids(index)); }; switch (opcode) { // Ops migrated to subclasses. case HloOpcode::kBatchNormTraining: - CHECK_EQ(proto.operand_ids_size(), 3); + TF_RET_CHECK(proto.operand_ids_size() == 3) + << "BatchNormTraining instruction should have 3 operands but sees " + << proto.operand_ids_size(); instruction = CreateBatchNormTraining( proto.shape(), operands(0), operands(1), operands(2), proto.epsilon(), proto.feature_index()); break; case HloOpcode::kBatchNormInference: - CHECK_EQ(proto.operand_ids_size(), 5); + TF_RET_CHECK(proto.operand_ids_size() == 5) + << "BatchNormInference instruction should have 5 operands but sees " + << proto.operand_ids_size(); instruction = CreateBatchNormInference( proto.shape(), operands(0), operands(1), operands(2), operands(3), operands(4), proto.epsilon(), proto.feature_index()); break; case HloOpcode::kBatchNormGrad: - CHECK_EQ(proto.operand_ids_size(), 5); + TF_RET_CHECK(proto.operand_ids_size() == 5) + << "BatchNormGrad instruction should have 5 operands but sees " + << proto.operand_ids_size(); instruction = CreateBatchNormGrad(proto.shape(), operands(0), operands(1), operands(2), operands(3), operands(4), proto.epsilon(), proto.feature_index()); break; case HloOpcode::kFft: { - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Fft instruction should have 1 operand but sees " + << proto.operand_ids_size(); std::vector fft_length(proto.fft_length().begin(), proto.fft_length().end()); instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(), @@ -97,75 +112,85 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kSend: - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Send instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateSend(operands(0), proto.channel_id()); break; case HloOpcode::kSendDone: - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "SendDone instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateSendDone(operands(0)); break; case HloOpcode::kRecv: - CHECK_EQ(proto.operand_ids_size(), 0); + TF_RET_CHECK(proto.operand_ids_size() == 0) + << "Recv instruction should have 0 operand but sees " + << proto.operand_ids_size(); instruction = CreateRecv(proto.shape().tuple_shapes(0), proto.channel_id()); break; case HloOpcode::kRecvDone: - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "RecvDone instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateRecvDone(operands(0)); break; case HloOpcode::kReverse: - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Reverse instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateReverse(proto.shape(), operands(0), std::vector(proto.dimensions().begin(), proto.dimensions().end())); break; - case HloOpcode::kConcatenate: { - CHECK_EQ(proto.dimensions_size(), 1); - std::vector concat_operands(proto.operand_ids_size()); - std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), - concat_operands.begin(), - [&instruction_map](int64 operand_id) { - return instruction_map.at(operand_id); - }); - instruction = CreateConcatenate(proto.shape(), concat_operands, - proto.dimensions(0)); + case HloOpcode::kConcatenate: + TF_RET_CHECK(proto.dimensions_size() == 1) + << "Concatenate instruction should have 1 dimension but sees " + << proto.dimensions_size(); + instruction = + CreateConcatenate(proto.shape(), all_operands(), proto.dimensions(0)); break; - } case HloOpcode::kReduce: - CHECK_EQ(proto.operand_ids_size(), 2); - CHECK_EQ(proto.called_computation_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Reduce instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Reduce instruction should have 1 called computation but sees " + << proto.called_computation_ids_size(); instruction = CreateReduce(proto.shape(), operands(0), operands(1), std::vector(proto.dimensions().begin(), proto.dimensions().end()), computations(0)); break; case HloOpcode::kTranspose: - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Transpose instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateTranspose(proto.shape(), operands(0), std::vector(proto.dimensions().begin(), proto.dimensions().end())); break; case HloOpcode::kBroadcast: - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Broadcast instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateBroadcast(proto.shape(), operands(0), std::vector(proto.dimensions().begin(), proto.dimensions().end())); break; - case HloOpcode::kMap: { - CHECK_EQ(proto.called_computation_ids_size(), 1); - std::vector map_operands(proto.operand_ids_size()); - std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), - map_operands.begin(), - [&instruction_map](int64 operand_id) { - return instruction_map.at(operand_id); - }); - instruction = CreateMap(proto.shape(), map_operands, computations(0)); + case HloOpcode::kMap: + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Map instruction should have 1 called computation but sees " + << proto.called_computation_ids_size(); + instruction = CreateMap(proto.shape(), all_operands(), computations(0)); break; - } case HloOpcode::kSlice: { - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Slice instruction should have 1 operand but sees " + << proto.operand_ids_size(); std::vector slice_starts, slice_limits, slice_strides; for (const HloInstructionProto::SliceDimensions& slice_dimensions : proto.slice_dimensions()) { @@ -192,7 +217,7 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.operand_ids_size() == 1) << "Trace instruction should have 1 operand but sees " << proto.operand_ids_size(); - CHECK(proto.has_literal()); + TF_RET_CHECK(proto.has_literal()); TF_ASSIGN_OR_RETURN(auto literal, Literal::CreateFromProto(proto.literal())); instruction = CreateTrace(literal->GetR1U8AsString(), operands(0)); @@ -208,37 +233,28 @@ StatusOr> HloInstruction::CreateFromProto( // Find the fused computation and set its fusion instruction. TF_RET_CHECK(proto.called_computation_ids_size() == 1) - << "Expect 1 called computation for fusion instruction, but sees " + << "Expect 1 called computation for fusion instruction but sees " << proto.called_computation_ids_size(); const int64 fusion_id = proto.called_computation_ids(0); auto* fused_computation = FindPtrOrNull(computation_map, fusion_id); TF_RET_CHECK(fused_computation != nullptr) << "No fusion computation with id " << fusion_id; - std::vector fusion_operands(proto.operand_ids_size()); - std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), - fusion_operands.begin(), - [&instruction_map](int64 operand_id) { - return instruction_map.at(operand_id); - }); - instruction = CreateFusion(proto.shape(), fusion_kind, fusion_operands, + instruction = CreateFusion(proto.shape(), fusion_kind, all_operands(), fused_computation); break; } - case HloOpcode::kRng: { - std::vector rng_parms(proto.operand_ids_size()); - std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), - rng_parms.begin(), [&instruction_map](int64 operand_id) { - return instruction_map.at(operand_id); - }); - instruction = CreateRng(proto.shape(), proto.distribution(), rng_parms); + case HloOpcode::kRng: + instruction = + CreateRng(proto.shape(), proto.distribution(), all_operands()); break; - } case HloOpcode::kParameter: instruction = CreateParameter(proto.parameter_number(), proto.shape(), proto.name()); break; case HloOpcode::kGetTupleElement: - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "GetTupleElement instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateGetTupleElement(proto.shape(), operands(0), proto.tuple_index()); break; @@ -247,26 +263,113 @@ StatusOr> HloInstruction::CreateFromProto( CreateReducePrecision(proto.shape(), operands(0), proto.exponent_bits(), proto.mantissa_bits()); break; - case HloOpcode::kInfeed: - instruction = CreateInfeed(proto.shape(), proto.infeed_config()); - break; + case HloOpcode::kInfeed: { + const Shape& data_shape = + ShapeUtil::GetTupleElementShape(proto.shape(), 0); + if (proto.operand_ids_size() == 0) { + // TODO(b/80000000): Remove this when all uses of infeed are + // converted to take tokens. + instruction = CreateInfeed(data_shape, proto.infeed_config()); + } else { + CHECK_EQ(proto.operand_ids_size(), 2); + instruction = + CreateInfeed(data_shape, operands(0), proto.infeed_config()); + } + } break; case HloOpcode::kOutfeed: - instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), - proto.outfeed_config()); + if (proto.operand_ids_size() == 1) { + // TODO(b/80000000): Remove this when all uses of outfeed are + // converted to take tokens. + instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), + proto.outfeed_config()); + } else { + CHECK_EQ(proto.operand_ids_size(), 2); + instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), + operands(1), proto.outfeed_config()); + } break; case HloOpcode::kCrossReplicaSum: { - CHECK_EQ(proto.called_computation_ids_size(), 1); - std::vector all_operands(proto.operand_ids_size()); - c_transform(proto.operand_ids(), all_operands.begin(), - [&instruction_map](int64 operand_id) { - return instruction_map.at(operand_id); - }); + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "CrossReplicaSum should have 1 called computation but sees " + << proto.called_computation_ids_size(); + tensorflow::gtl::optional all_reduce_id; + if (proto.all_reduce_id() > 0) { + all_reduce_id = proto.all_reduce_id(); + } instruction = CreateCrossReplicaSum( - proto.shape(), all_operands, computations(0), + proto.shape(), all_operands(), computations(0), /*replica_group_ids=*/ std::vector(proto.replica_group_ids().begin(), proto.replica_group_ids().end()), - /*barrier=*/""); + /*barrier=*/proto.cross_replica_sum_barrier(), + /*all_reduce_id=*/all_reduce_id); + break; + } + case HloOpcode::kConvolution: + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Convolution instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_window()); + TF_RET_CHECK(proto.has_convolution_dimension_numbers()); + instruction = + CreateConvolve(proto.shape(), operands(0), operands(1), + proto.window(), proto.convolution_dimension_numbers()); + break; + case HloOpcode::kReduceWindow: + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "ReduceWindow instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "ReduceWindow should have 1 called computation but sees " + << proto.called_computation_ids_size(); + instruction = CreateReduceWindow(proto.shape(), operands(0), operands(1), + proto.window(), computations(0)); + break; + case HloOpcode::kSelectAndScatter: + TF_RET_CHECK(proto.operand_ids_size() == 3) + << "SelectAndScatter instruction should have 3 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.called_computation_ids_size() == 2) + << "SelectAndScatter should have 2 called computations but sees " + << proto.called_computation_ids_size(); + instruction = CreateSelectAndScatter( + proto.shape(), operands(0), computations(0), proto.window(), + operands(1), operands(2), computations(1)); + break; + case HloOpcode::kCustomCall: + instruction = CreateCustomCall(proto.shape(), all_operands(), + proto.custom_call_target()); + if (proto.has_window()) { + static_cast(instruction.get()) + ->set_window(proto.window()); + } + if (proto.has_convolution_dimension_numbers()) { + static_cast(instruction.get()) + ->set_convolution_dimension_numbers( + proto.convolution_dimension_numbers()); + } + break; + case HloOpcode::kHostCompute: + instruction = + CreateHostCompute(proto.shape(), all_operands(), proto.channel_name(), + proto.cost_estimate_ns()); + break; + case HloOpcode::kPad: + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Pad instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_padding_config()); + instruction = CreatePad(proto.shape(), operands(0), operands(1), + proto.padding_config()); + break; + case HloOpcode::kDynamicSlice: { + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "DynamicSlice instruction should have 2 operands but sees " + << proto.operand_ids_size(); + std::vector slice_sizes(proto.dynamic_slice_sizes_size()); + c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin()); + instruction = CreateDynamicSlice(proto.shape(), operands(0), operands(1), + slice_sizes); break; } default: { @@ -299,28 +402,11 @@ StatusOr> HloInstruction::CreateFromProto( instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); - if (proto.has_window()) { - instruction->window_ = MakeUnique(proto.window()); - } - if (proto.has_convolution_dimension_numbers()) { - instruction->convolution_dimension_numbers_ = - MakeUnique( - proto.convolution_dimension_numbers()); - } if (proto.has_dot_dimension_numbers()) { instruction->dot_dimension_numbers_ = MakeUnique(proto.dot_dimension_numbers()); } - for (int64 dynamic_slice_size : proto.dynamic_slice_sizes()) { - instruction->dynamic_slice_sizes_.push_back(dynamic_slice_size); - } - if (proto.has_padding_config()) { - instruction->padding_config_ = - MakeUnique(proto.padding_config()); - } - instruction->custom_call_target_ = proto.custom_call_target(); - if (proto.has_sharding()) { TF_ASSIGN_OR_RETURN(const auto& sharding, HloSharding::FromProto(proto.sharding())); @@ -334,10 +420,6 @@ StatusOr> HloInstruction::CreateFromProto( for (int64 bound : proto.gather_window_bounds()) { instruction->gather_window_bounds_.push_back(bound); } - - instruction->channel_name_ = proto.channel_name(); - instruction->cost_estimate_ns_ = proto.cost_estimate_ns(); - return std::move(instruction); } @@ -407,7 +489,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: - case HloOpcode::kSort: case HloOpcode::kTanh: break; default: @@ -442,6 +523,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kSubtract: case HloOpcode::kAnd: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: @@ -478,30 +560,16 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateMap( const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation, - tensorflow::gtl::ArraySlice static_operands) { - return MakeUnique(shape, operands, map_computation, - static_operands); + HloComputation* map_computation) { + return MakeUnique(shape, operands, map_computation); } /* static */ std::unique_ptr HloInstruction::CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kConvolution, shape)); - if (window_util::HasBaseDilation(window)) { - instruction->name_ = instruction->name() + "-base-dilated"; - } - if (window_util::HasWindowDilation(window)) { - instruction->name_ = instruction->name() + "-window-dilated"; - } - instruction->AppendOperand(lhs); - instruction->AppendOperand(rhs); - instruction->window_ = MakeUnique(window); - instruction->convolution_dimension_numbers_ = - MakeUnique(dimension_numbers); - return instruction; + return MakeUnique(shape, lhs, rhs, window, + dimension_numbers); } /* static */ std::unique_ptr HloInstruction::CreateFft( @@ -557,14 +625,28 @@ HloInstruction::CreateCrossReplicaSum( } /* static */ std::unique_ptr HloInstruction::CreateInfeed( - const Shape& shape, const string& config) { - return MakeUnique(shape, config); + const Shape& infeed_shape, HloInstruction* token_operand, + const string& config) { + return MakeUnique(infeed_shape, token_operand, config); +} + +/* static */ std::unique_ptr HloInstruction::CreateInfeed( + const Shape& infeed_shape, const string& config) { + return MakeUnique(infeed_shape, config); } /* static */ std::unique_ptr HloInstruction::CreateOutfeed( - const Shape& shape, HloInstruction* operand, + const Shape& outfeed_shape, HloInstruction* operand, + HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) { + return MakeUnique(outfeed_shape, operand, + token_operand, outfeed_config); +} + +/* static */ std::unique_ptr HloInstruction::CreateOutfeed( + const Shape& outfeed_shape, HloInstruction* operand, tensorflow::StringPiece outfeed_config) { - return MakeUnique(shape, operand, outfeed_config); + return MakeUnique(outfeed_shape, operand, + outfeed_config); } /* static */ std::unique_ptr HloInstruction::CreateSend( @@ -599,11 +681,10 @@ HloInstruction::CreateCrossReplicaSum( return MakeUnique(shape, operand, dimensions); } -/* static */ std::unique_ptr -HloInstruction::CreateGenerateToken( +/* static */ std::unique_ptr HloInstruction::CreateAfterAll( tensorflow::gtl::ArraySlice operands) { - auto instruction = WrapUnique(new HloInstruction( - HloOpcode::kGenerateToken, ShapeUtil::MakeTokenShape())); + auto instruction = WrapUnique( + new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); for (auto operand : operands) { instruction->AppendOperand(operand); } @@ -650,13 +731,8 @@ HloInstruction::CreateGenerateToken( /* static */ std::unique_ptr HloInstruction::CreateDynamicSlice( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, tensorflow::gtl::ArraySlice slice_sizes) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kDynamicSlice, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(start_indices); - instruction->dynamic_slice_sizes_.assign(slice_sizes.begin(), - slice_sizes.end()); - return instruction; + return MakeUnique(shape, operand, start_indices, + slice_sizes); } /* static */ std::unique_ptr @@ -705,13 +781,8 @@ HloInstruction::CreateBitcastConvert(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateReduceWindow( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, const Window& window, HloComputation* reduce_computation) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kReduceWindow, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(init_value); - instruction->called_computations_.push_back(reduce_computation); - instruction->window_ = MakeUnique(window); - return instruction; + return MakeUnique(shape, operand, init_value, + window, reduce_computation); } /* static */ std::unique_ptr @@ -749,16 +820,8 @@ HloInstruction::CreateSelectAndScatter( const Shape& shape, HloInstruction* operand, HloComputation* select, const Window& window, HloInstruction* source, HloInstruction* init_value, HloComputation* scatter) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kSelectAndScatter, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(source); - instruction->AppendOperand(init_value); - // Select comes before scatter in the vector. - instruction->called_computations_.push_back(select); - instruction->called_computations_.push_back(scatter); - instruction->window_ = MakeUnique(window); - return instruction; + return MakeUnique( + shape, operand, select, window, source, init_value, scatter); } /* static */ std::unique_ptr HloInstruction::CreateBroadcast( @@ -823,11 +886,8 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreatePad( const Shape& shape, HloInstruction* operand, HloInstruction* padding_value, const PaddingConfig& padding_config) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kPad, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(padding_value); - instruction->padding_config_ = MakeUnique(padding_config); - return instruction; + return MakeUnique(shape, operand, padding_value, + padding_config); } /* static */ std::unique_ptr HloInstruction::CreateReshape( @@ -847,6 +907,16 @@ HloInstruction::CreateBroadcastSequence( return MakeUnique(shape, operand, dimensions); } +/* static */ std::unique_ptr HloInstruction::CreateSort( + const Shape& shape, HloInstruction* keys, HloInstruction* values) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kSort, shape)); + instruction->AppendOperand(keys); + if (values) { + instruction->AppendOperand(values); + } + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { return MakeUnique(shape, fusion_kind, fused_root); @@ -924,26 +994,15 @@ bool HloInstruction::HasSideEffect() const { /* static */ std::unique_ptr HloInstruction::CreateCustomCall( const Shape& shape, tensorflow::gtl::ArraySlice operands, tensorflow::StringPiece custom_call_target) { - std::unique_ptr instruction = - WrapUnique(new HloInstruction(HloOpcode::kCustomCall, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->custom_call_target_ = std::string(custom_call_target); - return instruction; + return MakeUnique(shape, operands, + custom_call_target); } /* static */ std::unique_ptr HloInstruction::CreateHostCompute( const Shape& shape, tensorflow::gtl::ArraySlice operands, tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) { - std::unique_ptr instruction = - WrapUnique(new HloInstruction(HloOpcode::kHostCompute, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->channel_name_ = std::string(channel_name); - instruction->cost_estimate_ns_ = cost_estimate_ns; - return instruction; + return MakeUnique(shape, operands, channel_name, + cost_estimate_ns); } /* static */ std::unique_ptr HloInstruction::CreateTuple( @@ -1043,6 +1102,13 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kCrossReplicaSum: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: + case HloOpcode::kConvolution: + case HloOpcode::kCustomCall: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kHostCompute: + case HloOpcode::kPad: + case HloOpcode::kDynamicSlice: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; // Unary ops. @@ -1065,7 +1131,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: - case HloOpcode::kSort: case HloOpcode::kTanh: CHECK_EQ(new_operands.size(), 1); clone = CreateUnary(shape, opcode_, new_operands[0]); @@ -1089,6 +1154,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kRemainder: case HloOpcode::kAnd: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: @@ -1106,21 +1172,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kCall: clone = CreateCall(shape, new_operands, to_apply()); break; - case HloOpcode::kCustomCall: - clone = CreateCustomCall(shape, new_operands, custom_call_target_); - if (window_ != nullptr) { - clone->window_ = MakeUnique(*window_); - } - if (convolution_dimension_numbers_ != nullptr) { - clone->convolution_dimension_numbers_ = - MakeUnique( - *convolution_dimension_numbers_); - } - break; - case HloOpcode::kHostCompute: - clone = CreateHostCompute(shape, new_operands, channel_name_, - cost_estimate_ns_); - break; case HloOpcode::kConvert: CHECK_EQ(new_operands.size(), 1); clone = CreateConvert(shape, new_operands[0]); @@ -1129,40 +1180,15 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateBitcastConvert(shape, new_operands[0]); break; - case HloOpcode::kConvolution: - CHECK_EQ(new_operands.size(), 2); - clone = CreateConvolve(shape, new_operands[0], new_operands[1], *window_, - *convolution_dimension_numbers_); - break; case HloOpcode::kDot: CHECK_EQ(new_operands.size(), 2); clone = CreateDot(shape, new_operands[0], new_operands[1], *dot_dimension_numbers_); break; - case HloOpcode::kPad: - CHECK_EQ(new_operands.size(), 2); - clone = - CreatePad(shape, new_operands[0], new_operands[1], *padding_config_); - break; - case HloOpcode::kReduceWindow: - CHECK_EQ(new_operands.size(), 2); - clone = CreateReduceWindow(shape, new_operands[0], new_operands[1], - *window_, to_apply()); - break; - case HloOpcode::kSelectAndScatter: - CHECK_EQ(new_operands.size(), 3); - clone = - CreateSelectAndScatter(shape, new_operands[0], select(), *window_, - new_operands[1], new_operands[2], scatter()); - break; case HloOpcode::kReshape: CHECK_EQ(new_operands.size(), 1); clone = CreateReshape(shape, new_operands[0]); break; - case HloOpcode::kDynamicSlice: - clone = CreateDynamicSlice(shape, new_operands[0], new_operands[1], - dynamic_slice_sizes_); - break; case HloOpcode::kDynamicUpdateSlice: CHECK_EQ(new_operands.size(), 3); clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1], @@ -1194,8 +1220,16 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(), user_side_metadata_->Clone()); break; - case HloOpcode::kGenerateToken: - clone = CreateGenerateToken(new_operands); + case HloOpcode::kAfterAll: + clone = CreateAfterAll(new_operands); + break; + case HloOpcode::kSort: + CHECK(new_operands.size() == 1 || new_operands.size() == 2) + << "Too many operands for sort: " << new_operands.size(); + HloInstruction* keys = new_operands[0]; + HloInstruction* values = + new_operands.size() == 2 ? new_operands[1] : nullptr; + clone = CreateSort(shape, keys, values); break; } SetupDerivedInstruction(clone.get()); @@ -1380,6 +1414,30 @@ void HloInstruction::AppendOperand(HloInstruction* operand) { operand->AddUser(this); } +void HloInstruction::RemoveOperandsAtAscendingIndices( + tensorflow::gtl::ArraySlice ascending_indices) { + if (ascending_indices.empty()) { + return; + } + int next_index = 0; + int removed_count = 0; + for (int to_remove : ascending_indices) { + while (next_index < to_remove) { + operands_[next_index - removed_count] = operands_[next_index]; + ++next_index; + } + CHECK_LT(to_remove, operands_.size()); + ++removed_count; + ++next_index; + } + while (next_index < operands_.size()) { + operands_[next_index - removed_count] = operands_[next_index]; + ++next_index; + } + CHECK_EQ(removed_count, ascending_indices.size()); + operands_.resize(operands_.size() - removed_count); +} + void HloInstruction::AddUser(HloInstruction* user) { if (!ContainsKey(user_set_, user)) { user_set_.insert(user); @@ -1417,7 +1475,6 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kDivide: - case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kEq: case HloOpcode::kExp: @@ -1433,6 +1490,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kAnd: case HloOpcode::kNot: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -1449,6 +1507,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: case HloOpcode::kSign: + case HloOpcode::kSort: case HloOpcode::kSin: case HloOpcode::kSubtract: case HloOpcode::kTanh: @@ -1458,15 +1517,9 @@ bool HloInstruction::IdenticalSlowPath( // These opcodes have complex or special behavior so just return false. case HloOpcode::kDomain: case HloOpcode::kWhile: - case HloOpcode::kGenerateToken: + case HloOpcode::kAfterAll: return false; - // Convolution has a window and dimensions. - case HloOpcode::kConvolution: - return protobuf_util::ProtobufEquals(window(), other.window()) && - protobuf_util::ProtobufEquals( - convolution_dimension_numbers(), - other.convolution_dimension_numbers()); // Check dot dimension numbers. case HloOpcode::kDot: return protobuf_util::ProtobufEquals(dot_dimension_numbers(), @@ -1477,51 +1530,13 @@ bool HloInstruction::IdenticalSlowPath( other.gather_dimension_numbers()) && gather_window_bounds() == other.gather_window_bounds(); - case HloOpcode::kReduceWindow: - return eq_computations(to_apply(), other.to_apply()) && - protobuf_util::ProtobufEquals(window(), other.window()); - - // SelectAndScatter is determined by both select and scatter - // computation as well as the window configuration. - case HloOpcode::kSelectAndScatter: - return eq_computations(select(), other.select()) && - eq_computations(scatter(), other.scatter()) && - protobuf_util::ProtobufEquals(window(), other.window()); - // Remaining instructions with special values. - case HloOpcode::kPad: - return protobuf_util::ProtobufEquals(padding_config(), - other.padding_config()); case HloOpcode::kCall: return eq_computations(to_apply(), other.to_apply()); - case HloOpcode::kCrossReplicaSum: - return replica_group_ids() == other.replica_group_ids() && - cross_replica_sum_barrier() == other.cross_replica_sum_barrier() && - eq_computations(to_apply(), other.to_apply()); - case HloOpcode::kCustomCall: - if ((window_ == nullptr) != (other.window_ == nullptr) || - (window_ != nullptr && - !protobuf_util::ProtobufEquals(window(), other.window()))) { - return false; - } - if ((convolution_dimension_numbers_ == nullptr) != - (other.convolution_dimension_numbers_ == nullptr) || - (convolution_dimension_numbers_ != nullptr && - !protobuf_util::ProtobufEquals( - convolution_dimension_numbers(), - other.convolution_dimension_numbers()))) { - return false; - } - return custom_call_target_ == other.custom_call_target_; case HloOpcode::kConditional: return eq_computations(true_computation(), other.true_computation()) && eq_computations(false_computation(), other.false_computation()); - // These opcodes are not yet supported. - case HloOpcode::kSort: - case HloOpcode::kHostCompute: - return false; - // Ops migrated to subclasses should never come to this line. // TODO(b/80131774): Remove this switch when migration is complete. case HloOpcode::kBatchNormTraining: @@ -1548,6 +1563,14 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kReducePrecision: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kConvolution: + case HloOpcode::kCustomCall: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kHostCompute: + case HloOpcode::kPad: + case HloOpcode::kDynamicSlice: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } @@ -1582,6 +1605,10 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user, std::replace(user->operands_.begin(), user->operands_.end(), this, new_producer); new_producer->AddUser(user); + if (user->opcode() == HloOpcode::kFusion) { + TF_RETURN_IF_ERROR( + Cast(user)->DeduplicateFusionOperands()); + } return Status::OK(); } @@ -1590,6 +1617,10 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num, TF_RET_CHECK(operand_num >= 0); TF_RET_CHECK(operand_num < operand_count()); HloInstruction* old_operand = mutable_operand(operand_num); + if (old_operand == new_operand) { + return Status::OK(); + } + TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(), new_operand->shape())) << old_operand->shape().ShortDebugString() << " is not compatible with " @@ -1620,6 +1651,10 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { std::replace(user->operands_.begin(), user->operands_.end(), this, new_producer); new_producer->AddUser(user); + if (user->opcode() == HloOpcode::kFusion) { + TF_RETURN_IF_ERROR( + Cast(user)->DeduplicateFusionOperands()); + } } } users_.clear(); @@ -1668,11 +1703,6 @@ void HloInstruction::set_to_apply(HloComputation* computation) { } } -const string& HloInstruction::custom_call_target() const { - CHECK_EQ(opcode_, HloOpcode::kCustomCall); - return custom_call_target_; -} - HloComputation* HloInstruction::while_condition() const { CHECK_EQ(HloOpcode::kWhile, opcode_); return called_computations_[kConditionComputationIndex]; @@ -1699,32 +1729,6 @@ void HloInstruction::set_while_body(HloComputation* computation) { called_computations_[kBodyComputationIndex] = computation; } -HloComputation* HloInstruction::select() const { - CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - return called_computations_[kSelectComputationIndex]; -} - -HloComputation* HloInstruction::scatter() const { - CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - return called_computations_[kScatterComputationIndex]; -} - -void HloInstruction::set_select(HloComputation* computation) { - // Don't allow changing the computation for fused instructions so we don't - // have to recompute called_instructions for the entire fusion instruction. - CHECK(!IsFused()); - CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - called_computations_[kSelectComputationIndex] = computation; -} - -void HloInstruction::set_scatter(HloComputation* computation) { - // Don't allow changing the computation for fused instructions so we don't - // have to recompute called_instructions for the entire fusion instruction. - CHECK(!IsFused()); - CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - called_computations_[kScatterComputationIndex] = computation; -} - HloComputation* HloInstruction::true_computation() const { CHECK_EQ(HloOpcode::kConditional, opcode_); return called_computations_[kTrueComputationIndex]; @@ -1820,6 +1824,7 @@ bool HloInstruction::IsElementwiseImpl( case HloOpcode::kSubtract: case HloOpcode::kAnd: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: @@ -1832,6 +1837,9 @@ bool HloInstruction::IsElementwiseImpl( case HloOpcode::kClamp: return true; + case HloOpcode::kDynamicUpdateSlice: + return operand_idx.has_value() && operand_idx.value() == 0; + default: return false; } @@ -1925,24 +1933,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( std::vector HloInstruction::ExtraAttributesToString( const HloPrintOptions& options) const { std::vector extra = ExtraAttributesToStringImpl(options); - if (window_ != nullptr && window_->dimensions_size() != 0) { - extra.push_back(StrCat("window={", window_util::ToString(*window_), "}")); - } - if (padding_config_ != nullptr) { - extra.push_back( - StrCat("padding=", xla::PaddingConfigToString(*padding_config_))); - } - if (opcode() == HloOpcode::kDynamicSlice) { - extra.push_back( - StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")); - } - - if (convolution_dimension_numbers_ != nullptr) { - extra.push_back(StrCat( - "dim_labels=", - ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_))); - } if (dot_dimension_numbers_ != nullptr) { extra.push_back(DotDimensionNumbersToString()); } @@ -2041,14 +2032,6 @@ std::vector HloInstruction::ExtraAttributesToString( ", exit=", user_side_metadata_->ToString(), "}")); } - // By contract, we print the custom call target even if - // options.print_subcomputation_mode() == kOff, because the call target is not - // an HloComputation. - if (opcode() == HloOpcode::kCustomCall) { - extra.push_back( - StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); - } - return extra; } @@ -2085,13 +2068,6 @@ HloInstructionProto HloInstruction::ToProto() const { } } - if (window_ != nullptr) { - *proto.mutable_window() = *window_; - } - if (convolution_dimension_numbers_ != nullptr) { - *proto.mutable_convolution_dimension_numbers() = - *convolution_dimension_numbers_; - } if (dot_dimension_numbers_ != nullptr) { *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; } @@ -2104,21 +2080,10 @@ HloInstructionProto HloInstruction::ToProto() const { } } - for (int64 slice_size : dynamic_slice_sizes_) { - proto.add_dynamic_slice_sizes(slice_size); - } - if (padding_config_ != nullptr) { - *proto.mutable_padding_config() = *padding_config_; - } - proto.set_custom_call_target(custom_call_target_); - if (has_sharding()) { *proto.mutable_sharding() = sharding().ToProto(); } - proto.set_channel_name(channel_name_); - proto.set_cost_estimate_ns(cost_estimate_ns_); - return proto; } @@ -2128,35 +2093,6 @@ string HloInstruction::ToCategory() const { return "data formatting"; } - if (opcode() == HloOpcode::kConvolution) { - string category = "convolution"; - if (window_util::HasBaseDilation(window())) { - category += " base-dilated"; - } - if (window_util::HasWindowDilation(window())) { - category += " window-dilated"; - } - return category; - } - - // Give transpose-dot and backwards-conv fusions the categories "dot" and - // "convolution" so they match the categories of proper kDot and kConvolution - // ops. These fusion categories are really just a way of expressing a - // particular kind of dot or conv, so they should have the same category as a - // vanilla dot/conv. - if (opcode() == HloOpcode::kFusion) { - switch (fusion_kind()) { - case FusionKind::kLoop: - return "loop fusion"; - case FusionKind::kInput: - return "input fusion"; - case FusionKind::kOutput: - return "output fusion"; - case FusionKind::kCustom: - return "custom fusion"; - } - } - if (IsElementwise()) { return "non-fusion elementwise"; } @@ -2242,6 +2178,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleAnd(this); case HloOpcode::kOr: return visitor->HandleOr(this); + case HloOpcode::kXor: + return visitor->HandleXor(this); case HloOpcode::kShiftLeft: return visitor->HandleShiftLeft(this); case HloOpcode::kShiftRightArithmetic: @@ -2366,8 +2304,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleGather(this); case HloOpcode::kDomain: return visitor->HandleDomain(this); - case HloOpcode::kGenerateToken: - return visitor->HandleGenerateToken(this); + case HloOpcode::kAfterAll: + return visitor->HandleAfterAll(this); // These opcodes are not handled here. case HloOpcode::kTrace: @@ -3175,4 +3113,61 @@ tensorflow::gtl::optional HloInstruction::all_reduce_id() const { return Cast(this)->all_reduce_id(); } +const ConvolutionDimensionNumbers& +HloInstruction::convolution_dimension_numbers() const { + if (auto convolution = DynCast(this)) { + return convolution->convolution_dimension_numbers(); + } + if (auto custom_call = DynCast(this)) { + return custom_call->convolution_dimension_numbers(); + } + LOG(FATAL) << "Unimplemented method."; +} + +void HloInstruction::set_convolution_dimension_numbers( + const ConvolutionDimensionNumbers& dnums) { + if (auto convolution = DynCast(this)) { + convolution->set_convolution_dimension_numbers(dnums); + } else if (auto custom_call = DynCast(this)) { + custom_call->set_convolution_dimension_numbers(dnums); + } else { + LOG(FATAL) << "Unimplemented method."; + } +} + +HloComputation* HloInstruction::select() const { + return Cast(this)->select(); +} + +HloComputation* HloInstruction::scatter() const { + return Cast(this)->scatter(); +} + +void HloInstruction::set_select(HloComputation* computation) { + return Cast(this)->set_select(computation); +} + +void HloInstruction::set_scatter(HloComputation* computation) { + return Cast(this)->set_scatter(computation); +} + +const string& HloInstruction::custom_call_target() const { + return Cast(this)->custom_call_target(); +} + +const string& HloInstruction::channel_name() const { + return Cast(this)->channel_name(); +} + +const PaddingConfig& HloInstruction::padding_config() const { + return Cast(this)->padding_config(); +} + +int64 HloInstruction::slice_sizes(int64 dimension) const { + return Cast(this)->slice_sizes(dimension); +} + +const std::vector& HloInstruction::dynamic_slice_sizes() const { + return Cast(this)->dynamic_slice_sizes(); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 8a0ffc21cd49270316619022a243bf8e16ed1d98..04590721271f17e3731ff7a01c435c439d47058b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -389,11 +389,10 @@ class HloInstruction { // Creates a map instruction, where the computation (given by the handle) is // applied element-wise to every element in operands (across the operands, - // at a given index) with the same `static_operands`. + // at a given index) static std::unique_ptr CreateMap( const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation, - tensorflow::gtl::ArraySlice static_operands = {}); + HloComputation* map_computation); // Creates a convolution op, where rhs is the convolutional filter // and window describes how the filter is applied to lhs. @@ -459,13 +458,29 @@ class HloInstruction { const Shape& shape, HloInstruction* operand); // Creates an infeed instruction, which reads data of the given shape from the - // Infeed interface of the device. - static std::unique_ptr CreateInfeed(const Shape& shape, + // Infeed interface of the device. infeed_shape is the shape of the data + // received from the infeed *not* the shape of the infeed instruction which + // is a tuple containing the infeed_shape and the TOKEN. + static std::unique_ptr CreateInfeed( + const Shape& infeed_shape, HloInstruction* token_operand, + const string& config); + // Overload which does not require a token. + // TODO(b/80000000): Remove this overload when all uses of infeed are + // converted to take tokens. + static std::unique_ptr CreateInfeed(const Shape& infeed_shape, const string& config); - // Creates an outfeed instruction, which outputs data. + // Creates an outfeed instruction, which outputs data. outfeed_shape is the + // shape of the data being outfed *not* the shape of the outfeed instruction + // which is a TOKEN. static std::unique_ptr CreateOutfeed( - const Shape& shape, HloInstruction* operand, + const Shape& outfeed_shape, HloInstruction* operand, + HloInstruction* token_operand, tensorflow::StringPiece outfeed_config); + // Overload which does not require a token. + // TODO(b/80000000): Remove this overload when all uses of infeed are + // converted to take tokens. + static std::unique_ptr CreateOutfeed( + const Shape& outfeed_shape, HloInstruction* operand, tensorflow::StringPiece outfeed_config); // Creates an asynchronous send instruction with the given channel id, which @@ -596,6 +611,11 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); + // Creates a sort op, with a keys operand, and an optional values operand. + static std::unique_ptr CreateSort( + const Shape& shape, HloInstruction* keys, + HloInstruction* values = nullptr); + // Creates a while instruction, given a condition computation, a body // computation, and the initial value for the input of the computations. For // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1 @@ -665,9 +685,9 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); - // Creates a token instruction used for joining or creating token types which - // thread through side-effecting operations. - static std::unique_ptr CreateGenerateToken( + // Creates a token instruction used for joining or creating new values of + // token type which thread through side-effecting operations. + static std::unique_ptr CreateAfterAll( tensorflow::gtl::ArraySlice operands); // Creates an instance of GatherDimensionNumbers. @@ -811,9 +831,15 @@ class HloInstruction { // Replaces the use of this instruction in "user" with "new_producer". Note // that there might be multiple uses of this instruction in "user"; all will // be replaced. + // + // If user is a fusion instruction, this function will remove any duplicated + // operands of it which could be created due to this replacement. Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer); // Replaces the specified operand with new_operand. + // + // This function does NOT remove duplicated operands even if this instruction + // is a fusion, so that the existing operand numbers do not change. Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand); // Replaces all uses of this instruction with the new producer. If @@ -822,6 +848,9 @@ class HloInstruction { // // If this instruction is the root of its computation, sets the computation's // root to new_producer. + // + // If a user is a fusion instruction, this function will remove any duplicated + // operands of it which could be created due to this replacement. Status ReplaceAllUsesWith(HloInstruction* new_producer); // Performs a postorder DFS visit using this node as the root. If @@ -896,10 +925,6 @@ class HloInstruction { HloComputation* to_apply() const; void set_to_apply(HloComputation* to_apply); - // Returns the custom_call_target for CustomCall. - // Precondition: opcode() == HloOpcode::kCustomCall - const string& custom_call_target() const; - // Gets/sets the while_condition or while_body HloComputation for While. The // setters should only be called by HloModule or HloComputation methods. // @@ -909,15 +934,6 @@ class HloInstruction { void set_while_condition(HloComputation* while_condition); void set_while_body(HloComputation* while_body); - // Gets/sets the select or scatter HloComputation for SelectAndScatter. The - // setters should only be called by HloModule or HloComputation methods. - // - // Precondition: opcode() == HloOpcode::kSelectAndScatter. - HloComputation* select() const; - HloComputation* scatter() const; - void set_select(HloComputation* select); - void set_scatter(HloComputation* scatter); - // Gets/sets the true and false HloComputation for Conditional. The setters // should only be called by HloModule or HloComputation methods. // @@ -959,7 +975,7 @@ class HloInstruction { // Returns a category for the HLO. This could be something like "convolution" // or "elementwise". - string ToCategory() const; + virtual string ToCategory() const; // Returns a logging instruction, if the output of this instruction is logged. // @@ -967,12 +983,6 @@ class HloInstruction { HloInstruction* tracing() const; void set_tracing(HloInstruction* trace_instruction); - // Returns the channel name associated with the instruction. The name is - // used to identify host Send/Recv operations. - // - // Precondition: opcode() == HloOpcode::kHostCompute - string channel_name() const { return channel_name_; } - // Returns true if this instruction is fused, ie contained within a fusion // instruction. bool IsFused() const; @@ -1052,56 +1062,6 @@ class HloInstruction { copy_elision_allowed_ = value; } - // Returns the size of the slice in the given dimension for a dynamic - // slice node. - // - // Precondition: opcode() == HloOpcode::kDynamicSlice - int64 slice_sizes(int64 dimension) const { - CHECK_EQ(HloOpcode::kDynamicSlice, opcode_); - return dynamic_slice_sizes_[dimension]; - } - const std::vector& dynamic_slice_sizes() const { - CHECK_EQ(HloOpcode::kDynamicSlice, opcode_); - return dynamic_slice_sizes_; - } - - // Returns data on the window in a windowed operation such as - // convolution. - const Window& window() const { - CHECK(window_ != nullptr); - return *window_; - } - - // Sets the window data in a windowed operation such as convolution. - void set_window(const Window& window) { - window_ = MakeUnique(window); - } - - // Returns the padding configuration for a pad node. - // - // Precondition: opcode() == HloOpcode::kPad - const PaddingConfig& padding_config() const { - CHECK(padding_config_ != nullptr); - return *padding_config_; - } - - // Returns data on the dimension numbers used for a convolution operation, - // which may be a kConvolution instruction or a kCustomCall that implements a - // convolution. - const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { - CHECK(convolution_dimension_numbers_ != nullptr); - return *convolution_dimension_numbers_; - } - - // Sets the convolution dimension numbers on this instruction. In general you - // shouldn't need to call this; instead, specify the convolution dimension - // numbers when you create the instruction. - void set_convolution_dimension_numbers( - const ConvolutionDimensionNumbers& dnums) { - convolution_dimension_numbers_ = - MakeUnique(dnums); - } - // Returns data on the dimension numbers used for a dot operation. const DotDimensionNumbers& dot_dimension_numbers() const { CHECK(dot_dimension_numbers_ != nullptr); @@ -1441,6 +1401,55 @@ class HloInstruction { // Delegates to HloAllReduceInstruction::all_reduce_id. tensorflow::gtl::optional all_reduce_id() const; + + // Returns data on the window in a windowed operation such as + // convolution. + virtual const Window& window() const { + LOG(FATAL) << "Unimplemented method."; + } + + // Sets the window data in a windowed operation such as convolution. + virtual void set_window(const Window& window) { + LOG(FATAL) << "Unimplemented method."; + } + + // Returns data on the dimension numbers used for a convolution operation, + // which may be a kConvolution instruction or a kCustomCall that implements a + // convolution. + const ConvolutionDimensionNumbers& convolution_dimension_numbers() const; + + // Sets the convolution dimension numbers on this instruction. In general you + // shouldn't need to call this; instead, specify the convolution dimension + // numbers when you create the instruction. + void set_convolution_dimension_numbers( + const ConvolutionDimensionNumbers& dnums); + + // Delegates to HloSelectAndScatterInstruction::select. + HloComputation* select() const; + + // Delegates to HloSelectAndScatterInstruction::scatter. + HloComputation* scatter() const; + + // Delegates to HloSelectAndScatterInstruction::set_select. + void set_select(HloComputation* computation); + + // Delegates to HloSelectAndScatterInstruction::set_scatter. + void set_scatter(HloComputation* computation); + + // Delegates to HloCustomCallInstruction::custom_call_target. + const string& custom_call_target() const; + + // Delegates to HloHostComputeInstruction::channel_name. + const string& channel_name() const; + + // Delegates to HloPadInstruction::padding_config. + const PaddingConfig& padding_config() const; + + // Delegates to HloDynamicSliceInstruction::slice_sizes. + int64 slice_sizes(int64 dimension) const; + + // Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes. + const std::vector& dynamic_slice_sizes() const; // Old methods kept for smooth subclassing transition END. protected: @@ -1460,12 +1469,35 @@ class HloInstruction { operands_.erase(operands_.begin() + index); } + // Removes a list of operands with the given indices in ascending order. + void RemoveOperandsAtAscendingIndices( + tensorflow::gtl::ArraySlice ascending_indices); + void AppendComputation(HloComputation* computation) { called_computations_.push_back(computation); } void DetachFrom(HloInstruction* usee) { usee->RemoveUser(this); } + void set_called_computation(int index, HloComputation* computation) { + called_computations_[index] = computation; + } + // Indices of computations in called_computations_ for instructions which call + // multiple computations. + enum { + // kWhile computations. + kBodyComputationIndex = 0, + kConditionComputationIndex = 1, + + // kSelectAndScatter computations. + kSelectComputationIndex = 0, + kScatterComputationIndex = 1, + + // kConditional computations. + kTrueComputationIndex = 0, + kFalseComputationIndex = 1, + }; + private: // Implementation for non-common logic of CloneWithNewOperands. virtual std::unique_ptr CloneWithNewOperandsImpl( @@ -1558,12 +1590,6 @@ class HloInstruction { // Result shape of this instruction. Shape shape_; - // Describes the window in a windowed operation such as convolution. - std::unique_ptr window_; - - // Describes the dimension numbers used for a convolution. - std::unique_ptr convolution_dimension_numbers_; - // Describes the dimension numbers used for a dot. std::unique_ptr dot_dimension_numbers_; @@ -1573,14 +1599,6 @@ class HloInstruction { // Used to tag kCopy instructions that are eligible for copy elision. bool copy_elision_allowed_ = true; - // Describes the [start, start + size) range size for a dynamic slice - // ('start' is specified dynamically in the second operand of the operation). - std::vector dynamic_slice_sizes_; - - // The padding configuration that describes the edge padding and interior - // padding of this pad instruction. Only set for pad instructions. - std::unique_ptr padding_config_; - // The sharding, if one exists. std::unique_ptr sharding_; @@ -1588,34 +1606,9 @@ class HloInstruction { std::unique_ptr operand_side_metadata_; std::unique_ptr user_side_metadata_; - // Name of a global symbol to call, only present for kCustomCall. - string custom_call_target_; - - // Name to use for host send/recv channels, only present for kHostCompute. - string channel_name_; - - // Estimate of the duration of a host computation in nanoseconds. - int64 cost_estimate_ns_ = 0; - // Computations called by this instruction. std::vector called_computations_; - // Indices of computations in called_computations_ for instructions which call - // multiple computations. - enum { - // kWhile computations. - kBodyComputationIndex = 0, - kConditionComputationIndex = 1, - - // kSelectAndScatter computations. - kSelectComputationIndex = 0, - kScatterComputationIndex = 1, - - // kConditional computations. - kTrueComputationIndex = 0, - kFalseComputationIndex = 1, - }; - // A trace instruction that consumes this instruction. // // Invariant: if trace_instruction_ != nullptr, trace_instruction has this as diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 8ee24f9d92f61453a19a019c6e9c22ce37be1589..d8ca99dfd12ef95ab5e1ea61093d8bf3ea97a5e2 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -716,10 +716,11 @@ TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) { }))); auto shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); auto shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}); + auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto outfeed10 = builder.AddInstruction( - HloInstruction::CreateOutfeed(shape10, constant, "")); + HloInstruction::CreateOutfeed(shape10, constant, token, "")); auto outfeed01 = builder.AddInstruction( - HloInstruction::CreateOutfeed(shape01, constant, "")); + HloInstruction::CreateOutfeed(shape01, constant, token, "")); auto clone01 = builder.AddInstruction(outfeed01->Clone()); auto clone10 = builder.AddInstruction(outfeed10->Clone()); @@ -763,12 +764,12 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { HloComputation::Builder builder(TestName()); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); - auto map_1_x = builder.AddInstruction(HloInstruction::CreateMap( - scalar_shape, {constant}, computation_x, /*static_operands=*/{})); - auto map_2_x = builder.AddInstruction(HloInstruction::CreateMap( - scalar_shape, {map_1_x}, computation_x, /*static_operands=*/{})); - auto map_3_y = builder.AddInstruction(HloInstruction::CreateMap( - scalar_shape, {map_2_x}, computation_y, /*static_operands=*/{})); + auto map_1_x = builder.AddInstruction( + HloInstruction::CreateMap(scalar_shape, {constant}, computation_x)); + auto map_2_x = builder.AddInstruction( + HloInstruction::CreateMap(scalar_shape, {map_1_x}, computation_x)); + auto map_3_y = builder.AddInstruction( + HloInstruction::CreateMap(scalar_shape, {map_2_x}, computation_y)); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( @@ -1170,6 +1171,40 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { EXPECT_TRUE(StructuralEqual(*fusion, *fusion2)); } +TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) { + // Fused expression: + // + // x y + // | | + // | transpose + // \ / + // dot + const Shape s = ShapeUtil::MakeShape(F32, {10, 10}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(s, x, reshape, dot_dnums)); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + HloInstruction* fusion = computation->CreateFusionInstruction( + {dot, reshape}, HloInstruction::FusionKind::kLoop); + + EXPECT_TRUE(x->ReplaceAllUsesWith(y).ok()); + + EXPECT_THAT(fusion->operands(), UnorderedElementsAre(y)); + EXPECT_EQ(fusion->fused_instructions_computation()->num_parameters(), 1); +} + TEST_F(HloInstructionTest, FusionEquality) { auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 5871a6605fed24865d8cbe7e1cee5a4d5fadb357..e2f43f581091af49a4bdb96c8c42eb52035ce6fd 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -20,6 +20,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace { @@ -280,7 +282,7 @@ HloAllReduceInstruction::HloAllReduceInstruction( cross_replica_sum_barrier_(barrier.begin(), barrier.end()), all_reduce_id_(all_reduce_id) { // TODO(b/79737069): Remove the CHECK when supported. - CHECK(!all_reduce_id_.has_value()); + CHECK(!all_reduce_id_); for (auto operand : operands) { AppendOperand(operand); } @@ -292,7 +294,11 @@ HloInstructionProto HloAllReduceInstruction::ToProto() const { for (int64 i : replica_group_ids_) { proto.add_replica_group_ids(i); } - // TODO(b/79737069): handle barrier and all_reduce_id. + // Proto3 is so sad. + if (all_reduce_id_) { + proto.set_all_reduce_id(*all_reduce_id_); + } + proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); return proto; } @@ -303,7 +309,7 @@ std::vector HloAllReduceInstruction::ExtraAttributesToStringImpl( if (!cross_replica_sum_barrier().empty()) { result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); } - if (all_reduce_id_.has_value()) { + if (all_reduce_id_) { result.push_back(StrCat("all_reduce_id=", *all_reduce_id_)); } return result; @@ -548,10 +554,8 @@ HloBroadcastInstruction::CloneWithNewOperandsImpl( HloMapInstruction::HloMapInstruction( const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation, - tensorflow::gtl::ArraySlice static_operands) + HloComputation* map_computation) : HloInstruction(HloOpcode::kMap, shape) { - CHECK(static_operands.empty()) << "static_operands not yet supported"; for (auto operand : operands) { AppendOperand(operand); } @@ -802,6 +806,19 @@ HloFusionInstruction::HloFusionInstruction( fusion_computation->SetFusionInstruction(this); } +string HloFusionInstruction::ToCategory() const { + switch (fusion_kind()) { + case FusionKind::kLoop: + return "loop fusion"; + case FusionKind::kInput: + return "input fusion"; + case FusionKind::kOutput: + return "output fusion"; + case FusionKind::kCustom: + return "custom fusion"; + } +} + HloInstructionProto HloFusionInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_fusion_kind(xla::ToString(fusion_kind())); @@ -812,10 +829,6 @@ HloInstructionProto HloFusionInstruction::ToProto() const { bool HloFusionInstruction::IsElementwiseImpl( const tensorflow::gtl::optional& operand_idx) const { - if (fusion_kind() != FusionKind::kLoop) { - return false; - } - if (!operand_idx.has_value()) { for (auto* fused : fused_instructions()) { if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) { @@ -1196,6 +1209,26 @@ std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( new_fused_computation); } +Status HloFusionInstruction::DeduplicateFusionOperands() { + tensorflow::gtl::FlatMap operand_indices; + std::vector operands_to_remove; + for (int i = 0; i < operand_count(); ++i) { + auto emplace_result = operand_indices.emplace(operand(i), i); + if (!emplace_result.second) { + TF_RETURN_IF_ERROR(fused_parameter(i)->ReplaceAllUsesWith( + fused_parameter(emplace_result.first->second))); + operands_to_remove.push_back(i); + } + } + if (operands_to_remove.empty()) { + return Status::OK(); + } + TF_RETURN_IF_ERROR( + fused_instructions_computation()->RemoveUnusedParameters()); + RemoveOperandsAtAscendingIndices(operands_to_remove); + return Status::OK(); +} + HloRngInstruction::HloRngInstruction( const Shape& shape, RandomDistribution distribution, tensorflow::gtl::ArraySlice parameters) @@ -1351,9 +1384,22 @@ HloReducePrecisionInstruction::CloneWithNewOperandsImpl( shape, new_operands[0], exponent_bits(), mantissa_bits()); } -HloInfeedInstruction::HloInfeedInstruction(const Shape& shape, +HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape, + HloInstruction* token_operand, const string& config) - : HloInstruction(HloOpcode::kInfeed, shape), infeed_config_(config) {} + : HloInstruction(HloOpcode::kInfeed, + ShapeUtil::MakeTupleShape( + {infeed_shape, ShapeUtil::MakeTokenShape()})), + infeed_config_(config) { + AppendOperand(token_operand); +} + +HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape, + const string& config) + : HloInstruction(HloOpcode::kInfeed, + ShapeUtil::MakeTupleShape( + {infeed_shape, ShapeUtil::MakeTokenShape()})), + infeed_config_(config) {} HloInstructionProto HloInfeedInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); @@ -1381,19 +1427,37 @@ std::unique_ptr HloInfeedInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { - CHECK_EQ(new_operands.size(), 0); - return MakeUnique(shape, infeed_config()); + if (new_operands.empty()) { + return MakeUnique(infeed_shape(), infeed_config()); + } else { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(infeed_shape(), new_operands[0], + infeed_config()); + } } HloOutfeedInstruction::HloOutfeedInstruction( - const Shape& shape, HloInstruction* operand, + const Shape& outfeed_shape, HloInstruction* operand, + HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) + : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()), + outfeed_shape_(outfeed_shape), + outfeed_config_(outfeed_config.begin(), outfeed_config.end()) { + CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape)) + << "Outfeed shape " << outfeed_shape + << " must be compatible with operand shape " << operand->shape(); + AppendOperand(operand); + AppendOperand(token_operand); +} + +HloOutfeedInstruction::HloOutfeedInstruction( + const Shape& outfeed_shape, HloInstruction* operand, tensorflow::StringPiece outfeed_config) - : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeNil()), - outfeed_shape_(shape), + : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()), + outfeed_shape_(outfeed_shape), outfeed_config_(outfeed_config.begin(), outfeed_config.end()) { - CHECK(ShapeUtil::Compatible(operand->shape(), shape)) - << "Outfeed shape " << shape << " must be compatible with operand shape " - << operand->shape(); + CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape)) + << "Outfeed shape " << outfeed_shape + << " must be compatible with operand shape " << operand->shape(); AppendOperand(operand); } @@ -1424,9 +1488,370 @@ std::unique_ptr HloOutfeedInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { - CHECK_EQ(new_operands.size(), 1); - return MakeUnique(outfeed_shape(), new_operands[0], - outfeed_config()); + if (new_operands.size() == 1) { + return MakeUnique(outfeed_shape(), new_operands[0], + outfeed_config()); + } else { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique(outfeed_shape(), new_operands[0], + new_operands[1], outfeed_config()); + } } +HloConvolutionInstruction::HloConvolutionInstruction( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const Window& window, const ConvolutionDimensionNumbers& dimension_numbers) + : HloInstruction(HloOpcode::kConvolution, shape), + window_(window), + convolution_dimension_numbers_(dimension_numbers) { + if (window_util::HasBaseDilation(window)) { + SetAndSanitizeName(StrCat(name(), "-base-dilated")); + } + if (window_util::HasWindowDilation(window)) { + SetAndSanitizeName(StrCat(name(), "-window-dilated")); + } + AppendOperand(lhs); + AppendOperand(rhs); +} + +string HloConvolutionInstruction::ToCategory() const { + string category = "convolution"; + if (window_util::HasBaseDilation(window())) { + category += " base-dilated"; + } + if (window_util::HasWindowDilation(window())) { + category += " window-dilated"; + } + return category; +} + +HloInstructionProto HloConvolutionInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_window() = window_; + *proto.mutable_convolution_dimension_numbers() = + convolution_dimension_numbers_; + return proto; +} + +std::vector HloConvolutionInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector extra; + if (window_.dimensions_size() != 0) { + extra.push_back(StrCat("window={", window_util::ToString(window()), "}")); + } + extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString( + convolution_dimension_numbers_))); + return extra; +} + +bool HloConvolutionInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return protobuf_util::ProtobufEquals(window(), casted_other.window()) && + protobuf_util::ProtobufEquals( + convolution_dimension_numbers(), + casted_other.convolution_dimension_numbers()); +} + +std::unique_ptr +HloConvolutionInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique(shape, new_operands[0], + new_operands[1], window(), + convolution_dimension_numbers_); +} + +HloReduceWindowInstruction::HloReduceWindowInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* init_value, + const Window& window, HloComputation* reduce_computation) + : HloInstruction(HloOpcode::kReduceWindow, shape), window_(window) { + AppendOperand(operand); + AppendOperand(init_value); + AppendComputation(reduce_computation); +} + +HloInstructionProto HloReduceWindowInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_window() = window_; + return proto; +} + +std::vector HloReduceWindowInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector extra; + if (window_.dimensions_size() != 0) { + extra.push_back(StrCat("window={", window_util::ToString(window()), "}")); + } + return extra; +} + +bool HloReduceWindowInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return eq_computations(to_apply(), casted_other.to_apply()) && + protobuf_util::ProtobufEquals(window(), casted_other.window()); +} + +std::unique_ptr +HloReduceWindowInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique( + shape, new_operands[0], new_operands[1], window(), to_apply()); +} + +HloSelectAndScatterInstruction::HloSelectAndScatterInstruction( + const Shape& shape, HloInstruction* operand, HloComputation* select, + const Window& window, HloInstruction* source, HloInstruction* init_value, + HloComputation* scatter) + : HloInstruction(HloOpcode::kSelectAndScatter, shape), window_(window) { + AppendOperand(operand); + AppendOperand(source); + AppendOperand(init_value); + // Select comes before scatter in the vector. + AppendComputation(select); + AppendComputation(scatter); +} + +HloInstructionProto HloSelectAndScatterInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_window() = window_; + return proto; +} + +std::vector HloSelectAndScatterInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector extra; + if (window_.dimensions_size() != 0) { + extra.push_back(StrCat("window={", window_util::ToString(window()), "}")); + } + return extra; +} + +bool HloSelectAndScatterInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return eq_computations(select(), casted_other.select()) && + eq_computations(scatter(), casted_other.scatter()) && + protobuf_util::ProtobufEquals(window(), casted_other.window()); +} + +std::unique_ptr +HloSelectAndScatterInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 3); + return MakeUnique( + shape, new_operands[0], select(), window(), new_operands[1], + new_operands[2], scatter()); +} + +HloCustomCallInstruction::HloCustomCallInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece custom_call_target) + : HloInstruction(HloOpcode::kCustomCall, shape), + custom_call_target_(custom_call_target.begin(), + custom_call_target.end()) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +HloInstructionProto HloCustomCallInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + if (window_ != nullptr) { + *proto.mutable_window() = *window_; + } + if (convolution_dimension_numbers_ != nullptr) { + *proto.mutable_convolution_dimension_numbers() = + *convolution_dimension_numbers_; + } + proto.set_custom_call_target(custom_call_target_); + return proto; +} + +std::vector HloCustomCallInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector extra; + if (window_ != nullptr && window_->dimensions_size() != 0) { + extra.push_back(StrCat("window={", window_util::ToString(*window_), "}")); + } + if (convolution_dimension_numbers_ != nullptr) { + extra.push_back(StrCat( + "dim_labels=", + ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_))); + } + // By contract, we print the custom call target even if + // options.print_subcomputation_mode() == kOff, because the call target is not + // an HloComputation. + extra.push_back( + StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); + return extra; +} + +bool HloCustomCallInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + if ((window_ == nullptr) != (casted_other.window_ == nullptr) || + (window_ != nullptr && + !protobuf_util::ProtobufEquals(*window_, *casted_other.window_))) { + return false; + } + if ((convolution_dimension_numbers_ == nullptr) != + (casted_other.convolution_dimension_numbers_ == nullptr) || + (convolution_dimension_numbers_ != nullptr && + !protobuf_util::ProtobufEquals( + convolution_dimension_numbers(), + casted_other.convolution_dimension_numbers()))) { + return false; + } + return custom_call_target_ == casted_other.custom_call_target_; +} + +std::unique_ptr +HloCustomCallInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + auto cloned = MakeUnique(shape, new_operands, + custom_call_target()); + if (window_ != nullptr) { + cloned->set_window(*window_); + } + if (convolution_dimension_numbers_ != nullptr) { + cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_); + } + return std::move(cloned); +} + +HloHostComputeInstruction::HloHostComputeInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) + : HloInstruction(HloOpcode::kHostCompute, shape), + channel_name_(channel_name.begin(), channel_name.end()), + cost_estimate_ns_(cost_estimate_ns) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +HloInstructionProto HloHostComputeInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_channel_name(channel_name_); + proto.set_cost_estimate_ns(cost_estimate_ns_); + return proto; +} + +bool HloHostComputeInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + // Not yet supported. + return false; +} + +std::unique_ptr +HloHostComputeInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique( + shape, new_operands, channel_name_, cost_estimate_ns_); +} + +HloPadInstruction::HloPadInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* padding_value, + const PaddingConfig& padding_config) + : HloInstruction(HloOpcode::kPad, shape), padding_config_(padding_config) { + AppendOperand(operand); + AppendOperand(padding_value); +} + +HloInstructionProto HloPadInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_padding_config() = padding_config_; + return proto; +} + +std::vector HloPadInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("padding=", xla::PaddingConfigToString(padding_config_))}; +} + +bool HloPadInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return protobuf_util::ProtobufEquals(padding_config(), + casted_other.padding_config()); +} + +std::unique_ptr HloPadInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique(shape, new_operands[0], new_operands[1], + padding_config_); +} + +HloDynamicSliceInstruction::HloDynamicSliceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, + tensorflow::gtl::ArraySlice slice_sizes) + : HloInstruction(HloOpcode::kDynamicSlice, shape), + dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) { + AppendOperand(operand); + AppendOperand(start_indices); +} + +HloInstructionProto HloDynamicSliceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 slice_size : dynamic_slice_sizes_) { + proto.add_dynamic_slice_sizes(slice_size); + } + return proto; +} + +std::vector HloDynamicSliceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return { + StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")}; +} + +bool HloDynamicSliceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return true; +} + +std::unique_ptr +HloDynamicSliceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique( + shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 04df2d860ebe2cd1b7f94a78598295d87b29986f..ec8a42bd3b965f3aad373afd25e76506b2ff3964 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -407,8 +407,7 @@ class HloMapInstruction : public HloInstruction { public: explicit HloMapInstruction( const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation, - tensorflow::gtl::ArraySlice static_operands = {}); + HloComputation* map_computation); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -557,6 +556,7 @@ class HloFusionInstruction : public HloInstruction { tensorflow::gtl::ArraySlice operands, HloComputation* fusion_computation); + string ToCategory() const override; // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -635,6 +635,9 @@ class HloFusionInstruction : public HloInstruction { void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; } + // If multiple operands are the same instruction, keeps only one of them. + Status DeduplicateFusionOperands(); + private: // Fuses the given instruction into this fusion instruction. When add_output // is false (which is the default), instruction_to_fuse is cloned and the @@ -784,12 +787,25 @@ class HloReducePrecisionInstruction : public HloInstruction { class HloInfeedInstruction : public HloInstruction { public: - explicit HloInfeedInstruction(const Shape& shape, const string& config); + explicit HloInfeedInstruction(const Shape& infeed_shape, + HloInstruction* token_operand, + const string& config); + // TODO(b/80000000): Remove this constructor when all uses of infeed are + // converted to take tokens. + explicit HloInfeedInstruction(const Shape& infeed_shape, + const string& config); // Returns the infeed configuration string. The infeed configuration includes // any metadata needed for the backend compiler (e.g., infeed buffer address) // and is target-dependent. string infeed_config() const { return infeed_config_; } void set_infeed_config(const string& config) { infeed_config_ = config; } + // Returns the shape of the data received by the infeed. This is not the same + // as the shape of the infeed instruction which produces a tuple containing + // the infeed data shape and a TOKEN. + const Shape& infeed_shape() const { + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape())); + return ShapeUtil::GetSubshape(shape(), {0}); + } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -812,11 +828,19 @@ class HloInfeedInstruction : public HloInstruction { class HloOutfeedInstruction : public HloInstruction { public: - explicit HloOutfeedInstruction(const Shape& shape, HloInstruction* operand, + explicit HloOutfeedInstruction(const Shape& outfeed_shape, + HloInstruction* operand, + HloInstruction* token_operand, tensorflow::StringPiece outfeed_config); + // TODO(b/80000000): Remove this constructor when all uses of outfeed are + // converted to take tokens. + explicit HloOutfeedInstruction(const Shape& outfeed_shape, + HloInstruction* operand, + tensorflow::StringPiece outfeed_config); + // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const { - TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape())); + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_)); return outfeed_shape_; } // Returns the config for the Outfeed instruction. @@ -842,6 +866,257 @@ class HloOutfeedInstruction : public HloInstruction { // Outfeed configuration information, only present for kOutfeed. string outfeed_config_; }; + +class HloConvolutionInstruction : public HloInstruction { + public: + explicit HloConvolutionInstruction( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers); + const Window& window() const override { return window_; } + void set_window(const Window& window) override { window_ = window; } + const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { + return convolution_dimension_numbers_; + } + void set_convolution_dimension_numbers( + const ConvolutionDimensionNumbers& dnums) { + convolution_dimension_numbers_ = dnums; + } + string ToCategory() const override; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + Window window_; + // Describes the dimension numbers used for a convolution. + ConvolutionDimensionNumbers convolution_dimension_numbers_; +}; + +class HloReduceWindowInstruction : public HloInstruction { + public: + explicit HloReduceWindowInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* init_value, + const Window& window, + HloComputation* reduce_computation); + const Window& window() const override { return window_; } + void set_window(const Window& window) override { window_ = window; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + Window window_; +}; + +class HloSelectAndScatterInstruction : public HloInstruction { + public: + explicit HloSelectAndScatterInstruction( + const Shape& shape, HloInstruction* operand, HloComputation* select, + const Window& window, HloInstruction* source, HloInstruction* init_value, + HloComputation* scatter); + const Window& window() const override { return window_; } + void set_window(const Window& window) override { window_ = window; } + // Gets/sets the select or scatter HloComputation for SelectAndScatter. The + // setters should only be called by HloModule or HloComputation methods. + HloComputation* select() const { + return called_computations()[kSelectComputationIndex]; + } + + HloComputation* scatter() const { + return called_computations()[kScatterComputationIndex]; + } + + void set_select(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); + set_called_computation(kSelectComputationIndex, computation); + } + + void set_scatter(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); + set_called_computation(kScatterComputationIndex, computation); + } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + Window window_; +}; + +class HloCustomCallInstruction : public HloInstruction { + public: + explicit HloCustomCallInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece custom_call_target); + const Window& window() const override { + CHECK(window_ != nullptr); + return *window_; + } + + void set_window(const Window& window) override { + window_ = MakeUnique(window); + } + + const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { + CHECK(convolution_dimension_numbers_ != nullptr); + return *convolution_dimension_numbers_; + } + + void set_convolution_dimension_numbers( + const ConvolutionDimensionNumbers& dnums) { + convolution_dimension_numbers_ = + MakeUnique(dnums); + } + const string& custom_call_target() const { return custom_call_target_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + // Name of a global symbol to call, only present for kCustomCall. + string custom_call_target_; + // Describes the window in a windowed operation such as convolution. + std::unique_ptr window_; + // Describes the dimension numbers used for a convolution. + std::unique_ptr convolution_dimension_numbers_; +}; + +class HloHostComputeInstruction : public HloInstruction { + public: + explicit HloHostComputeInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece channel_name, const int64 cost_estimate_ns); + // Returns the channel name associated with the instruction. The name is + // used to identify host Send/Recv operations. + const string& channel_name() const { return channel_name_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + // Name to use for host send/recv channels. + string channel_name_; + // Estimate of the duration of a host computation in nanoseconds. + int64 cost_estimate_ns_ = 0; +}; + +class HloPadInstruction : public HloInstruction { + public: + explicit HloPadInstruction(const Shape& shape, HloInstruction* operand, + HloInstruction* padding_value, + const PaddingConfig& padding_config); + // Returns the padding configuration for a pad node. + const PaddingConfig& padding_config() const { return padding_config_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The padding configuration that describes the edge padding and interior + // padding of this pad instruction. + PaddingConfig padding_config_; +}; + +class HloDynamicSliceInstruction : public HloInstruction { + public: + explicit HloDynamicSliceInstruction( + const Shape& shape, HloInstruction* operand, + HloInstruction* start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + // Old methods kept for smooth subclassing transition END. + // Returns the size of the slice in the given dimension for a dynamic + // slice node. + int64 slice_sizes(int64 dimension) const { + return dynamic_slice_sizes_[dimension]; + } + const std::vector& dynamic_slice_sizes() const { + return dynamic_slice_sizes_; + } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // Describes the [start, start + size) range size for a dynamic slice + // ('start' is specified dynamically in the second operand of the operation). + std::vector dynamic_slice_sizes_; +}; } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 8a31a8e617c1fb82201e07d9a3ff1ab9a618206b..b57c940238f0672692e3b65827f43e2f5499502d 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -187,7 +187,7 @@ HLO_MATCHER(Exp); HLO_MATCHER(Floor); HLO_MATCHER(Fusion); HLO_MATCHER(Ge); -HLO_MATCHER(GenerateToken); +HLO_MATCHER(AfterAll); HLO_MATCHER(Gt); HLO_MATCHER(Infeed); HLO_MATCHER(IsFinite); @@ -196,6 +196,7 @@ HLO_MATCHER(Log); HLO_MATCHER(And); HLO_MATCHER(Not); HLO_MATCHER(Or); +HLO_MATCHER(Xor); HLO_MATCHER(Lt); HLO_MATCHER(Map); HLO_MATCHER(Maximum); diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 9c59374b4a9d7e3dbfb99d8a6b30d4230e553658..39bc25ba42c2cb6a9f77e2726405311ba13b3edc 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -58,7 +58,7 @@ HloComputation* HloModule::AddComputationInternal( // If the module configuration has no entry layout computation set, create a // default one based on the program shape. - if (!config_.has_host_entry_computation_layout()) { + if (!config_.has_entry_computation_layout()) { config_.SetDefaultComputationLayout( entry_computation_->ComputeProgramShape()); } @@ -231,14 +231,11 @@ StatusOr> HloModule::CreateFromProto( TF_RET_CHECK(proto.has_program_shape()) << "No program shape found in the proto"; const auto& expected_program_shape = proto.program_shape(); - TF_RET_CHECK( - expected_program_shape.parameters_size() == - module_config.device_entry_computation_layout().parameter_count()); + TF_RET_CHECK(expected_program_shape.parameters_size() == + module_config.entry_computation_layout().parameter_count()); for (int i = 0; i < expected_program_shape.parameters_size(); ++i) { const Shape& parameter_shape = - module_config.device_entry_computation_layout() - .parameter_layout(i) - .shape(); + module_config.entry_computation_layout().parameter_layout(i).shape(); TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i), parameter_shape)) << "HloModuleConfig has different shape for parameter " << i @@ -248,7 +245,7 @@ StatusOr> HloModule::CreateFromProto( << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape); } const Shape& result_shape = - module_config.device_entry_computation_layout().result_layout().shape(); + module_config.entry_computation_layout().result_layout().shape(); TF_RET_CHECK( ShapeUtil::Compatible(expected_program_shape.result(), result_shape)) << "HloModuleConfig has different result shape than the HLO module. " @@ -327,7 +324,7 @@ StatusOr HloModule::CreateModuleConfigFromProto( // The module config is constructed with default layouts regardless of what is // passed in via the ProgramShape. Set the layouts to the appropriate values. ComputationLayout* entry_layout = - module_config.mutable_host_entry_computation_layout(); + module_config.mutable_entry_computation_layout(); for (int64 i = 0; i < entry_layout->parameter_count(); ++i) { TF_RETURN_IF_ERROR( entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( @@ -335,9 +332,6 @@ StatusOr HloModule::CreateModuleConfigFromProto( } TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape( program_shape.result())); - *module_config.mutable_device_entry_computation_layout() = - module_config.host_entry_computation_layout(); - return module_config; } @@ -451,7 +445,7 @@ int64 HloModule::instruction_count() const { return n; } -std::list HloModule::MakeComputationPostOrder() const { +std::vector HloModule::MakeComputationPostOrder() const { // First determine all root computations by building a set of nonroot // computations (computations which are called by an instruction in the // module). @@ -469,7 +463,7 @@ std::list HloModule::MakeComputationPostOrder() const { // order. This prevents duplication as an embedded computation may be called // from two different root computations. std::set added_computations; - std::list post_order; + std::vector post_order; for (auto& computation : computations_) { if (nonroot_computations.count(computation.get()) == 0) { for (HloComputation* embedded_computation : diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 757e65bda286d983d05e5a791aa7dffe97bac945..d2e726a0db63f622cd5092d56b4f746232d04aad 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -105,20 +105,19 @@ class HloModule { return entry_computation_; } - ComputationLayout* mutable_host_entry_computation_layout() { - return config_.mutable_host_entry_computation_layout(); + // Creates the ComputationLayout which describes the current status of the HLO + // module entry computation. + ComputationLayout compute_computation_layout() const { + return ComputationLayout(entry_computation()->ComputeProgramShape(), + /*ignore_layouts=*/false); } - const ComputationLayout& host_entry_computation_layout() const { - return config_.host_entry_computation_layout(); + ComputationLayout* mutable_entry_computation_layout() { + return config_.mutable_entry_computation_layout(); } - ComputationLayout* mutable_device_entry_computation_layout() { - return config_.mutable_device_entry_computation_layout(); - } - - const ComputationLayout& device_entry_computation_layout() const { - return config_.device_entry_computation_layout(); + const ComputationLayout& entry_computation_layout() const { + return config_.entry_computation_layout(); } // Gets the computations in this module. @@ -154,7 +153,7 @@ class HloModule { // Compute and return a post order of all computations in the module. The sort // is defined like so: if computation A has an instruction which calls // computation B, then A will appear after B in the sort. - std::list MakeComputationPostOrder() const; + std::vector MakeComputationPostOrder() const; // Gets the computations in this module which aren't for fusion nodes. // diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index dae5578a3158fecb8219e518841dec1020b2ca98..07a8c798dbee072db3b75d5e99ca0dcabb5fdf6b 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -28,16 +28,14 @@ namespace xla { using tensorflow::strings::StrAppend; -HloModuleConfig::HloModuleConfig() {} - -HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape) - : host_entry_computation_layout_(program_shape), - device_entry_computation_layout_(program_shape) {} +HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape, + bool ignore_layouts) + : entry_computation_layout_( + ComputationLayout(program_shape, ignore_layouts)) {} void HloModuleConfig::SetDefaultComputationLayout( const ProgramShape& program_shape) { - host_entry_computation_layout_ = ComputationLayout(program_shape); - device_entry_computation_layout_ = ComputationLayout(program_shape); + entry_computation_layout_ = ComputationLayout(program_shape); } string HloModuleConfig::compilation_cache_key() const { @@ -46,18 +44,11 @@ string HloModuleConfig::compilation_cache_key() const { StrAppend(&key, "::("); std::vector params; for (const ShapeLayout& param_layout : - host_entry_computation_layout_->parameter_layouts()) { + entry_computation_layout_->parameter_layouts()) { params.push_back(param_layout.shape().DebugString()); } StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ", - host_entry_computation_layout_->result_shape().SerializeAsString()); - for (const ShapeLayout& param_layout : - device_entry_computation_layout_->parameter_layouts()) { - params.push_back(param_layout.shape().DebugString()); - } - StrAppend( - &key, tensorflow::str_util::Join(params, ", "), ") => ", - device_entry_computation_layout_->result_shape().SerializeAsString()); + entry_computation_layout_->result_shape().SerializeAsString()); if (seed() != 0) { // TODO(b/32083678): force recompilation to reset global state. static std::atomic counter{0}; diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index cdb0b29a2399b387bc617262032e9083ba079625..074e9c90705d432b8344aebaf3c15aeb41a59fa3 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -37,48 +37,34 @@ class HloModuleConfig { // ComputationLayout. The default ctor creates it without -- in this case // accessing entry_computation_layout will CHECK-fail. The ctor accepting a // ProgramShape creates a computation layout using this shape. - HloModuleConfig(); - explicit HloModuleConfig(const ProgramShape& program_shape); + // The layouts in the ProgramShape will be reset to default unless + // ignore_layouts is set to false. + HloModuleConfig() = default; - // Checks if this config has an entry computation layout already. - bool has_host_entry_computation_layout() const { - return host_entry_computation_layout_.has_value(); - } + explicit HloModuleConfig(const ProgramShape& program_shape, + bool ignore_layouts = true); - bool has_device_entry_computation_layout() const { - return device_entry_computation_layout_.has_value(); + // Checks if this config has an entry computation layout already. + bool has_entry_computation_layout() const { + return entry_computation_layout_.has_value(); } // Sets the entry computation layout for this config. If the entry computation // layout already exists, it is silently replaced. void SetDefaultComputationLayout(const ProgramShape& program_shape); - // Returns a constant reference to the on-host layout of the entry - // computation. Assumes the layout was set. - const ComputationLayout& host_entry_computation_layout() const { - CHECK(host_entry_computation_layout_.has_value()); - return *host_entry_computation_layout_; - } - - // Returns a mutable pointer to the layout of the on-host entry computation. + // Returns a constant reference to the layout of the entry computation. // Assumes the layout was set. - ComputationLayout* mutable_host_entry_computation_layout() { - CHECK(host_entry_computation_layout_.has_value()); - return &(*host_entry_computation_layout_); - } - - // Returns a constant reference to the on-device layout of the entry - // computation. Assumes the layout was set. - const ComputationLayout& device_entry_computation_layout() const { - CHECK(device_entry_computation_layout_.has_value()); - return *device_entry_computation_layout_; + const ComputationLayout& entry_computation_layout() const { + CHECK(entry_computation_layout_.has_value()); + return *entry_computation_layout_; } - // Returns a mutable pointer to the layout of the on-device entry computation. + // Returns a mutable pointer to the layout of the entry computation. // Assumes the layout was set. - ComputationLayout* mutable_device_entry_computation_layout() { - CHECK(device_entry_computation_layout_.has_value()); - return &(*device_entry_computation_layout_); + ComputationLayout* mutable_entry_computation_layout() { + CHECK(entry_computation_layout_.has_value()); + return &(*entry_computation_layout_); } // Returns whether to enable HLO-level profiling. @@ -127,8 +113,7 @@ class HloModuleConfig { private: // If you add new members, be sure to update compilation_cache_key. - tensorflow::gtl::optional host_entry_computation_layout_; - tensorflow::gtl::optional device_entry_computation_layout_; + tensorflow::gtl::optional entry_computation_layout_; // Whether this is a 'host module'. bool is_host_module_ = false; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 5a0d1e264eb5095ff53721416ebcf4842a063f97..21a9b7291acc9e0066a9061facd13ab5acbf0bac 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -277,7 +277,7 @@ Status HloModuleGroupUtil::VerifyComputations( StatusOr> HloModuleGroupUtil::ComputeReachability( tensorflow::gtl::ArraySlice computations) { - std::list post_order; + std::vector post_order; auto visit_function = [&](HloInstruction* instruction, const std::vector& instruction_group) { diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index a35546f5f41b149d119ee141fd734da8bfd055b2..05e47a698f3b1d6345b183fb88b588a413063595 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -81,7 +81,7 @@ namespace xla { V(kFusion, "fusion", kHloOpcodeIsVariadic) \ V(kGather, "gather") \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ - V(kGenerateToken, "generate-token", kHloOpcodeIsVariadic) \ + V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ V(kHostCompute, "host-compute") \ @@ -94,6 +94,7 @@ namespace xla { V(kAnd, "and") \ V(kNot, "not") \ V(kOr, "or") \ + V(kXor, "xor") \ V(kLt, "less-than", kHloOpcodeIsComparison) \ V(kMap, "map", kHloOpcodeIsVariadic) \ V(kMaximum, "maximum") \ diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc index 774345124b4ad62e35d9423a23f1dbaa28e44d80..6f3f83f63a05fafaa3f3ddcff8a7cac7cb7b06d5 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -58,7 +58,7 @@ TEST(HloOpcodeTest, OpcodeProperties) { case HloOpcode::kConcatenate: case HloOpcode::kFusion: case HloOpcode::kMap: - case HloOpcode::kGenerateToken: + case HloOpcode::kAfterAll: case HloOpcode::kTuple: EXPECT_TRUE(HloOpcodeIsVariadic(opcode)); break; diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index fef475380c5c810e1c4712406dde6b1135be3d97..6ffed62a096043ecdd3609842da8161290d92e57 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -327,22 +327,15 @@ bool HloParser::ParseComputations() { // set the layouts to what the hlo text says. for (int p = 0; p < computation->num_parameters(); p++) { const Shape& param_shape = computation->parameter_instruction(p)->shape(); - TF_CHECK_OK(module_->mutable_host_entry_computation_layout() - ->mutable_parameter_layout(p) - ->CopyLayoutFromShape(param_shape)); - TF_CHECK_OK(module_->mutable_device_entry_computation_layout() + TF_CHECK_OK(module_->mutable_entry_computation_layout() ->mutable_parameter_layout(p) ->CopyLayoutFromShape(param_shape)); } const Shape& result_shape = computation->root_instruction()->shape(); - TF_CHECK_OK(module_->mutable_host_entry_computation_layout() - ->mutable_result_layout() - ->CopyLayoutFromShape(result_shape)); - TF_CHECK_OK(module_->mutable_device_entry_computation_layout() + TF_CHECK_OK(module_->mutable_entry_computation_layout() ->mutable_result_layout() ->CopyLayoutFromShape(result_shape)); } - return true; } @@ -516,7 +509,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: - case HloOpcode::kSort: case HloOpcode::kTanh: { if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -545,6 +537,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kRemainder: case HloOpcode::kAnd: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: { @@ -590,24 +583,27 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional to_apply; optional> replica_group_ids; optional barrier; + optional all_reduce_id; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; attrs["replica_group_ids"] = { /*required=*/false, AttrTy::kBracedInt64List, &replica_group_ids}; attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier}; + attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64, + &all_reduce_id}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - if (replica_group_ids) { instruction = builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( shape, operands, *to_apply, *replica_group_ids, - barrier ? *barrier : "")); + barrier ? *barrier : "", all_reduce_id)); } else { instruction = builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( - shape, operands, *to_apply, {}, barrier ? *barrier : "")); + shape, operands, *to_apply, {}, barrier ? *barrier : "", + all_reduce_id)); } break; } @@ -620,12 +616,33 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction::CreateReshape(shape, operands[0])); break; } - case HloOpcode::kGenerateToken: { + case HloOpcode::kAfterAll: { if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateGenerateToken(operands)); + instruction = + builder->AddInstruction(HloInstruction::CreateAfterAll(operands)); + break; + } + case HloOpcode::kSort: { + auto loc = lexer_.GetLoc(); + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + switch (operands.size()) { + case 1: + instruction = builder->AddInstruction( + HloInstruction::CreateSort(shape, /*keys=*/operands[0])); + break; + case 2: + instruction = builder->AddInstruction(HloInstruction::CreateSort( + shape, + /*keys=*/operands[0], /*values=*/operands[1])); + break; + default: + return Error(loc, StrCat("expects either 1 or 2 operands, but has ", + operands.size(), " operands")); + } break; } case HloOpcode::kTuple: { @@ -981,23 +998,53 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kInfeed: { optional config; attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config}; - if (!ParseOperands(&operands, /*expected_size=*/0) || - !ParseAttributes(attrs)) { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateInfeed(shape, config ? *config : "")); + // We need to know the infeed data shape to construct the infeed + // instruction. This is the zero-th element of the tuple-shaped output of + // the infeed instruction. ShapeUtil::GetTupleElementShape will check fail + // if the shape is not a non-empty tuple, so add guard so an error message + // can be emitted instead of a check fail + if (!ShapeUtil::IsTuple(shape) && !ShapeUtil::IsEmptyTuple(shape)) { + return Error(lexer_.GetLoc(), + "infeed must have a non-empty tuple shape"); + } + + if (operands.empty()) { + // TODO(b/80000000): Remove this when all uses of infeed are + // converted to take tokens. + instruction = builder->AddInstruction(HloInstruction::CreateInfeed( + ShapeUtil::GetTupleElementShape(shape, 0), config ? *config : "")); + } else if (operands.size() == 1) { + instruction = builder->AddInstruction(HloInstruction::CreateInfeed( + ShapeUtil::GetTupleElementShape(shape, 0), operands[0], + config ? *config : "")); + } else { + return Error(lexer_.GetLoc(), + "infeed must have exactly zero or one operands"); + } break; } case HloOpcode::kOutfeed: { optional config; attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config}; - if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseAttributes(attrs)) { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction(HloInstruction::CreateOutfeed( - operands[0]->shape(), operands[0], config ? *config : "")); + if (operands.size() == 1) { + // TODO(b/80000000): Remove this when all uses of outfeed are + // converted to take tokens. + instruction = builder->AddInstruction(HloInstruction::CreateOutfeed( + operands[0]->shape(), operands[0], config ? *config : "")); + } else if (operands.size() == 2) { + instruction = builder->AddInstruction( + HloInstruction::CreateOutfeed(operands[0]->shape(), operands[0], + operands[1], config ? *config : "")); + } else { + return Error(lexer_.GetLoc(), + "outfeed must have exactly one or two operands"); + } break; } case HloOpcode::kRng: { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index d551400d1ec62d659399e930529e4a4aa7bfaa7d..504ea3fe7adc09da46318198d6c6578c2bc932db 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -795,10 +795,14 @@ ENTRY ReduceR3ToR2.v3 { R"(HloModule outfeed_module ENTRY InfeedToOutfeed { - infeed = (u32[3]{0}, pred[]) infeed() - outfeed = () outfeed(infeed) - ROOT infeed.1 = (u32[3]{0}, pred[]) infeed() - outfeed.1 = () outfeed(infeed.1) + token = token[] after-all() + infeed = ((u32[3]{0}, pred[]), token[]) infeed(token) + infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0 + outfeed = token[] outfeed(infeed.data, token) + ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token) + infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0 + infeed.1.token = token[] get-tuple-element(infeed.1), index=1 + outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token) } )" @@ -826,6 +830,31 @@ ENTRY ReducePrecision { ROOT reduce-precision = f32[1]{0} reduce-precision(constant), exponent_bits=8, mantissa_bits=10 } +)" +}, +// Sort (Key) +{ +"SortKey", +R"(HloModule sort + +ENTRY Sort { + x = f32[1024]{0} parameter(0) + ROOT sorted = f32[1024]{0} sort(x) +} + +)" +}, +// Sort (Key, Value) +{ +"SortKeyValue", +R"(HloModule sort + +ENTRY Sort { + keys = f32[1024]{0} parameter(0) + values = s32[1024]{0} parameter(1) + ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values) +} + )" }, // Conditional @@ -1302,7 +1331,7 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { auto module = ParseHloString(original); TF_ASSERT_OK(module.status()); - auto program_layout = module.ValueOrDie()->host_entry_computation_layout(); + auto program_layout = module.ValueOrDie()->entry_computation_layout(); ASSERT_EQ(program_layout.parameter_count(), 1); auto param_layout = program_layout.parameter_layout(0).layout(); auto result_layout = program_layout.result_layout().layout(); @@ -1418,5 +1447,15 @@ TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) { EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums)); } +TEST_F(HloParserTest, NontupleInfeed) { + const string original = R"(HloModule nontuple_infeed: +ENTRY nontuple_infeed { + token = token[] after-all() + ROOT infeed = pred[] infeed(token) +})"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "infeed must have a non-empty tuple shape"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index d45038f1f4a2e4aa19234eec93fdc9a068a902e1..2418c19f3de7b036d7ef52d3a6db11de6316203b 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -61,7 +61,7 @@ bool AllOperandsAreConstants(const HloInstruction& instruction) { } HloInstruction* GetMatchingOperand( - std::function matcher, + const std::function& matcher, HloInstruction* instruction) { for (HloInstruction* op : instruction->operands()) { if (matcher(op)) { @@ -72,7 +72,7 @@ HloInstruction* GetMatchingOperand( } bool MatchBinaryInstructionOperand( - std::function matcher, + const std::function& matcher, HloInstruction* instruction, HloInstruction** matching_operand, HloInstruction** other_operand) { CHECK_EQ(instruction->operand_count(), 2); diff --git a/tensorflow/compiler/xla/service/hlo_query.h b/tensorflow/compiler/xla/service/hlo_query.h index c79347bbf9d6146943b7b787f713369cb37fadee..c0826a6aee1f693484207a86ec258c6604d92318 100644 --- a/tensorflow/compiler/xla/service/hlo_query.h +++ b/tensorflow/compiler/xla/service/hlo_query.h @@ -45,7 +45,7 @@ bool IsScalarConstant(const HloInstruction* instruction); // multiple matching operands, then the first matching operand is returned. If // there are no matching operands then nullptr is returned. HloInstruction* GetMatchingOperand( - std::function matcher, + const std::function& matcher, HloInstruction* instruction); // Returns whether a binary instruction has a matching operand. Sets @@ -53,7 +53,7 @@ HloInstruction* GetMatchingOperand( // other_operand. Note: in the case where both operands match, the first operand // of the instruction is returned. bool MatchBinaryInstructionOperand( - std::function matcher, + const std::function& matcher, HloInstruction* instruction, HloInstruction** matching_operand, HloInstruction** other_operand); diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 4738e46f8aeb96a4c25d04b3246bd21f644fe3ea..01b088a957554821e65db7bf9cedf334db49728f 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -18,7 +18,7 @@ limitations under the License. namespace xla { HloReachabilityMap::HloReachabilityMap( - const std::list& instructions) + tensorflow::gtl::ArraySlice instructions) : size_(instructions.size()) { bit_vectors_.reserve(size_); for (const HloInstruction* hlo : instructions) { diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index 69bb2b3cee6dafe058c45b4e74e93401bea2cfc9..48215d32a8284919cce6beb1663e6a723eefc1c4 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -41,7 +41,8 @@ class HloReachabilityMap { public: // Sets up a graph with no edges and where the nodes correspond to the given // instructions. - explicit HloReachabilityMap(const std::list& instructions); + explicit HloReachabilityMap( + tensorflow::gtl::ArraySlice instructions); // Set the reachability set of 'instruction' to the union of the reachability // sets of 'inputs'. Upon return, IsReachable(x, instruction) where diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 62c07d7fac93618a83b3b6111aec1e93309a0761..59a8800a7d6e9417c0e561db45341c912ad20464 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -1244,7 +1244,7 @@ StatusOr HloRematerialization::Run( // TODO(b/80249101): Instead of a separate copy elision pass, use the // ordering from the HLO schedule directly for copy insertion. SequentialHloOrdering ordering(module, *sequence); - TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, {}, module)); + TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, module)); } // Compute peak memory usage of all computations in the module called in a diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index e1f9d8efd4974055947438c8a2e15cb77d1b5c75..b2725e2918ce76248d9f2cdbb2a6e5a63226bf9a 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -98,8 +98,10 @@ StatusOr HloRunner::TransferLiteralToDevice( backend().transfer_manager()->AllocateScopedShapedBuffer( literal.shape(), backend().memory_allocator(), backend().default_device_ordinal())); + TF_ASSIGN_OR_RETURN( + auto stream, backend().BorrowStream(backend().default_stream_executor())); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - backend().default_stream_executor(), literal, buffer)); + stream.get(), literal, buffer)); return std::move(buffer); } @@ -127,8 +129,10 @@ StatusOr> HloRunner::TransferLiteralsToDevice( StatusOr> HloRunner::TransferLiteralFromDevice( const ShapedBuffer& buffer) { - return backend().transfer_manager()->TransferLiteralFromDevice( - backend().default_stream_executor(), buffer); + TF_ASSIGN_OR_RETURN( + auto stream, backend().BorrowStream(backend().default_stream_executor())); + return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(), + buffer); } StatusOr> HloRunner::Execute( @@ -176,8 +180,12 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( TF_ASSIGN_OR_RETURN(std::unique_ptr executable, CreateExecutable(std::move(module), run_hlo_passes)); - return executable->ExecuteOnStreamWrapper(&service_run_options, - /*profile=*/profile, arguments); + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer retval, + executable->ExecuteOnStreamWrapper(&service_run_options, + /*profile=*/profile, arguments)); + TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); + return std::move(retval); } StatusOr HloRunner::ExecuteWithDeviceBuffers( @@ -237,7 +245,7 @@ StatusOr>> HloRunner::ExecuteReplicated( backend().transfer_manager()->AllocateScopedShapedBuffer( argument->shape(), backend().memory_allocator(), device)); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - executor, *argument, argument_buffer)); + streams.back().get(), *argument, argument_buffer)); argument_buffers.push_back(std::move(argument_buffer)); argument_buffer_ptrs[index++] = &argument_buffers.back(); } @@ -305,9 +313,10 @@ StatusOr>> HloRunner::ExecuteReplicated( std::vector> exec_results; for (int64 i = 0; i < options.num_replicas; ++i) { + TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone()); TF_ASSIGN_OR_RETURN(std::unique_ptr literal, backend().transfer_manager()->TransferLiteralFromDevice( - streams[i]->parent(), results[i])); + streams[i].get(), results[i])); exec_results.push_back(std::move(literal)); } return std::move(exec_results); diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 641b9ecec9c55ab0d14c28a5c5e84b00c2322499..c6d3909af6103949daf4b0ab6be9b74724461e30 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -399,12 +399,9 @@ StatusOr> DFSMemoryScheduler( const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap& memory_by_computation) { - // This ordering is based on DFS post-order, with a heuristic to decide which - // operand to visit first. The heuristic is based on 'extra_users', which is - // simply users-1 for each instruction. By subtracting 1, we're saying that - // instructions with no users or a single user don't count; instructions with - // lots of fan-out will be visited earlier. + // These variables are a hack to prevent overflows. int64 cumulative_total_size = 0; + int64 total_hlos = computation.parent()->NumUniqueInstructionIds(); tensorflow::gtl::FlatMap extra_users; tensorflow::gtl::FlatMap total_sizes; for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { @@ -413,6 +410,11 @@ StatusOr> DFSMemoryScheduler( total_sizes[hlo] = 0; continue; } + // This ordering is based on DFS post-order, with a heuristic to decide + // which operand to visit first. The heuristic is based on 'extra_users', + // which is simply users-1 for each instruction. By subtracting 1, we're + // saying that instructions with no users or a single user don't count; + // instructions with lots of fan-out will be visited earlier. extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1; int64 logical_buffer_size = SumLogicalBufferSizes( points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); @@ -428,10 +430,13 @@ StatusOr> DFSMemoryScheduler( // lead to it. But computation is a DAG, so we are double-counting nodes, // which can lead to overflows for large programs. // cumulative_total_size caps the size to prevent overflows. + // Same for total_hlos: it prevents overflows on very large and branchy + // models, where the number of paths is exponential to the number of nodes. // NOTE(dimvar): this is quite ugly and should be changed. It's unclear // why we care about transitive sizes; when scheduling a node, its input // and output buffers should be all that matters, not its "history". total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size); + extra_users[hlo] = std::min(extra_users[hlo], total_hlos); } CHECK_EQ(extra_users.size(), computation.instruction_count()); CHECK_EQ(total_sizes.size(), computation.instruction_count()); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 9fb15df7c26951fb7f0d62b0d6533d6312e7a4d5..268b4727bcbed42ba71526f1d5ef5c887e941930 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -100,6 +100,29 @@ bool HloSharding::UsesDevice(int64 device) const { std::find(devices.begin(), devices.end(), device) != devices.end(); } +std::map HloSharding::UsedDevices(int64* count) const { + int64 element_count = 1; + std::map device_map; + if (IsTuple()) { + for (auto& tuple_element_sharding : tuple_elements()) { + auto unique_device = tuple_element_sharding.UniqueDevice(); + if (unique_device.ok()) { + device_map[unique_device.ValueOrDie()] += 1; + } + } + element_count = tuple_elements().size(); + } else { + auto unique_device = UniqueDevice(); + if (unique_device.ok()) { + device_map[unique_device.ValueOrDie()] += 1; + } + } + if (count != nullptr) { + *count = element_count; + } + return device_map; +} + std::vector HloSharding::TileIndexForDevice(int64 device) const { CHECK(!ShapeUtil::IsTuple(tile_shape_)); CHECK(!maximal_); @@ -439,6 +462,27 @@ tensorflow::gtl::optional HloSharding::ExtractSingleSharding() return tuple_elements_.front(); } +size_t HloSharding::Hash() const { + if (!tuple_) { + size_t h = 0; + for (const auto& element : tuple_elements_) { + h = tensorflow::Hash64Combine(h, element.Hash()); + } + return h; + } + if (replicated_) { + return 0; + } + size_t h = 0; + for (uint32 v : tile_assignment_) { + h = tensorflow::Hash64Combine(h, std::hash{}(v)); + } + for (uint32 v : tile_shape_.dimensions()) { + h = tensorflow::Hash64Combine(h, std::hash{}(v)); + } + return h; +} + std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) { out << sharding.ToString(); return out; diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 6a744e0247273e25c5de3143b7bbba2b79ee816a..34324d2058efe804cda486600dabd8a62cb84fda 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -19,7 +19,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_ +#include #include +#include #include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -118,6 +120,14 @@ class HloSharding { // Returns true if the sharding defines an operation on the given device. bool UsesDevice(int64 device) const; + // Retrieves an histogram of the devices used by the sharding. The returned + // map has the device number as key, and the occurrence count as value. + // If a sharding does not have a device, it will not be incuded in the + // histogram. The count argument, if not nullptr, will receive the total + // number of elements this sharding is made of (one for array, N leaves for + // tuples). + std::map UsedDevices(int64* count) const; + // Returns the tile that should be executed on the given device. // REQUIRES: !IsTuple() std::vector TileIndexForDevice(int64 device) const; @@ -179,26 +189,7 @@ class HloSharding { } bool operator!=(const HloSharding& other) const { return !(*this == other); } - size_t Hash() const { - if (!tuple_) { - size_t h = 0; - for (const auto& element : tuple_elements_) { - h = tensorflow::Hash64Combine(h, element.Hash()); - } - return h; - } - if (replicated_) { - return 0; - } - size_t h = 0; - for (uint32 v : tile_assignment_) { - h = tensorflow::Hash64Combine(h, std::hash{}(v)); - } - for (uint32 v : tile_shape_.dimensions()) { - h = tensorflow::Hash64Combine(h, std::hash{}(v)); - } - return h; - } + size_t Hash() const; struct Hasher { size_t operator()(const HloSharding& sharding) const { @@ -240,6 +231,12 @@ class HloSharding { tuple_(false), tile_shape_(), tile_assignment_({0}) {} + // device_id values: + // -2: magic number to mean unassigned device, used by spatial partitioning + // -1: the id of the host + // 0 or positive: the id of a device + // NOTE(dimvar): -1 is needed for outside compilation. It can be removed once + // we have fully switched to the side-effect tokens. explicit HloSharding(int64 device_id) : replicated_(false), maximal_(true), diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index 748273a43cecca7a9c7392bb84f0e4c7133cfb14..39036e205e76979e7da08246cd030ebd17e52f76 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -377,7 +377,7 @@ bool ShardingMetadata::Matches(const DomainMetadata& other) const { } string ShardingMetadata::ToString() const { - return sharding_ != nullptr ? sharding_->ToString() : "None"; + return sharding_ != nullptr ? sharding_->ToString() : "{}"; } Status ShardingMetadata::NormalizeInstructions( diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 1d6cd4cb2308fd09c7511e390a146a5224f253a3..27c9529b11181bed7a5d3977eeac7a66c066f8f8 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -15,6 +15,8 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -106,22 +108,50 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { reduce_precision->mantissa_bits())); } -Status ShapeVerifier::HandleInfeed(HloInstruction*) { return Status::OK(); } +Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { + HloInfeedInstruction* infeed = Cast(instruction); + // Infeed has an optional single token operand. + // TODO(b/80000000): Update when token is not optional. + if (infeed->operand_count() == 1 && + !ShapeUtil::Equal(infeed->operand(0)->shape(), + ShapeUtil::MakeTokenShape())) { + return InternalError( + "Expected infeed operand to be token-shaped, actual shape is %s:\n%s", + ShapeUtil::HumanString(infeed->operand(0)->shape()).c_str(), + infeed->ToString().c_str()); + } + + // The output of infeed is a tuple containing the data value and a token. + return CheckShape(infeed, + ShapeUtil::MakeTupleShape( + {infeed->infeed_shape(), ShapeUtil::MakeTokenShape()})); +} + +Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { + HloOutfeedInstruction* outfeed = Cast(instruction); + // Outfeed has an optional token operand (operand 1). + // TODO(b/80000000): Update when token is not optional. + if (outfeed->operand_count() == 2 && + !ShapeUtil::Equal(outfeed->operand(1)->shape(), + ShapeUtil::MakeTokenShape())) { + return InternalError( + "Expected operand 1 of outfeed to be a token, actual shape is %s:\n%s", + ShapeUtil::HumanString(outfeed->operand(1)->shape()).c_str(), + outfeed->ToString().c_str()); + } -Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { // Outfeed has a separate shape field for the value which is outfed to the - // host. The shape of the instruction itself is always nil because the outfeed - // produces no HLO value in the graph. + // host. The shape of the instruction itself is always a token. if (!ShapeUtil::Compatible(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) { return InternalError( - "Expected outfeed to have shape compatible with operand's shape %s, " + "Expected outfeed shape to be compatible with operand's shape %s, " "actual shape is %s:\n%s", ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(), ShapeUtil::HumanString(outfeed->outfeed_shape()).c_str(), outfeed->ToString().c_str()); } - return CheckShape(outfeed, ShapeUtil::MakeNil()); + return CheckShape(outfeed, ShapeUtil::MakeTokenShape()); } Status ShapeVerifier::HandleHostCompute(HloInstruction*) { @@ -137,7 +167,16 @@ Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { } Status ShapeVerifier::HandleSort(HloInstruction* sort) { - return CheckUnaryShape(sort); + if (sort->operand_count() == 2 && + !ShapeUtil::SameDimensions(sort->operand(0)->shape(), + sort->operand(1)->shape())) { + return InternalError( + "Expected sort to have to have the same dimensions for the keys and " + "the values. Keys shape is: %s\n, Values shape is: %s", + ShapeUtil::HumanString(sort->operand(0)->shape()).c_str(), + ShapeUtil::HumanString(sort->operand(1)->shape()).c_str()); + } + return CheckVariadicShape(sort); } Status ShapeVerifier::HandleConstant(HloInstruction* constant) { @@ -426,13 +465,12 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { gather->gather_dimension_numbers(), gather->gather_window_bounds())); } -Status ShapeVerifier::HandleGenerateToken(HloInstruction* token) { +Status ShapeVerifier::HandleAfterAll(HloInstruction* token) { std::vector operand_shapes; for (const HloInstruction* operand : token->operands()) { operand_shapes.push_back(&operand->shape()); } - return CheckShape(token, - ShapeInference::InferGenerateTokenShape(operand_shapes)); + return CheckShape(token, ShapeInference::InferAfterAllShape(operand_shapes)); } Status ShapeVerifier::CheckShape(const HloInstruction* instruction, @@ -786,8 +824,7 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { const Shape& out_shape = instruction->shape(); for (HloInstruction* operand : instruction->operands()) { const Shape& operand_shape = operand->shape(); - if (!ShapeUtil::IsScalar(operand_shape) && - !ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) { + if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) { return FailedPrecondition( "Implicit broadcast is not allowed in HLO." "Found non-compatible shapes for instruction %s.\n" @@ -815,9 +852,10 @@ bool ShapeContainsToken(const Shape& shape) { } // Verifies that all types entering and exiting the entry computation are -// legal. For example, TOKEN types have no Literal representation and cannot be -// on the interface of the entry computation (parameters and root instruction). +// legal. Status VerifyEntryAndExitShapes(const HloModule& module) { + // Tokens cannot be passed as entry parameters. + // TODO(b/80000000): Remove this constraint. for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) { HloInstruction* param = module.entry_computation()->parameter_instruction(i); @@ -827,14 +865,6 @@ Status VerifyEntryAndExitShapes(const HloModule& module) { ShapeUtil::HumanString(param->shape()).c_str()); } } - if (ShapeContainsToken( - module.entry_computation()->root_instruction()->shape())) { - return InternalError( - "Entry root is or contains a token shape: %s", - ShapeUtil::HumanString( - module.entry_computation()->root_instruction()->shape()) - .c_str()); - } return Status::OK(); } @@ -881,7 +911,9 @@ StatusOr HloVerifier::Run(HloModule* module) { << " != " << ShapeUtil::Rank(instruction->operand(0)->shape()); } else if (instruction->opcode() == HloOpcode::kWhile) { TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction)); - } else if (instruction->IsElementwise()) { + } else if (instruction->opcode() != + HloOpcode::kRng /* Rng operands are always scalar. */ + && instruction->IsElementwise()) { TF_RETURN_IF_ERROR(CheckElementwiseInstruction(instruction)); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 7283b3e7dcdbed5be18a1da1571287cf0c089288..da6b5d222206fe9bfcbf5157dc524ed46edaaac7 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -81,7 +81,7 @@ class ShapeVerifier : public DfsHloVisitor { HloInstruction* batch_norm_inference) override; Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleGather(HloInstruction* gather) override; - Status HandleGenerateToken(HloInstruction* token) override; + Status HandleAfterAll(HloInstruction* token) override; Status FinishVisit(HloInstruction*) override { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 8b3fa6c1572cf0ed91fc427722edcb23d8b8529d..1985d20578677ae68b244023c4640454b004bf49 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -28,6 +28,7 @@ namespace { using Analysis = IndexedArrayAnalysis; using UnknownArray = Analysis::UnknownArray; using ConstantArray = Analysis::ConstantArray; +using ReshapedArray = Analysis::ReshapedArray; using ScalarIndexedArray = Analysis::ScalarIndexedArray; using tensorflow::gtl::ArraySlice; using tensorflow::str_util::Join; @@ -52,6 +53,13 @@ string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { "(constant ", ShapeUtil::HumanString(root->shape()), ")"); } + case Array::kReshaped: { + ReshapedArray* reshaped_array = root->as(); + return tensorflow::strings::StrCat( + "(reshape ", ToString(reshaped_array->operand(), print_constants), + " to ", ShapeUtil::HumanString(reshaped_array->shape()), ")"); + } + case Array::kScalarIndexedConstant: case Array::kScalarIndexed: { auto* indexed_array = root->as(); @@ -239,15 +247,40 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForGather( tensorflow::gtl::ArraySlice window_bounds, Array* source, Array* indices) { if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) { + VLOG(3) << "ComputeArrayForGather: indices are not scalar"; return nullptr; } CHECK_EQ(dim_numbers.gather_dims_to_operand_dims_size(), 1); - if (!c_binary_search(dim_numbers.elided_window_dims(), - dim_numbers.gather_dims_to_operand_dims(0))) { + + // We can also handle dim_numbers.elided_window_dims_size() == 0 here, should + // it become relevant. + + if (dim_numbers.elided_window_dims_size() != 1 || + dim_numbers.elided_window_dims(0) != + dim_numbers.gather_dims_to_operand_dims(0)) { + VLOG(3) << "ComputeArrayForGather: gather operations must elide " + "gather_dims_to_operand_dims[0] and " + "gather_dims_to_operand_dims[0] only"; return nullptr; } + // ScalarIndexedArray cannot represent gathers that "slice" along some + // dimensions -- for instance it cannot represent a gather that picks 5 [2,3] + // arrays from an array of size [7,4,6]. We check that condition down below: + + for (int64 i = 0, e = source->shape().dimensions_size(); i < e; i++) { + if (i != dim_numbers.elided_window_dims(0) && + source->shape().dimensions(i) != window_bounds[i]) { + VLOG(3) << "ComputeArrayForGather: window_bounds[" << i + << "] != source->shape().dimensions(" << i << ") -- " + << source->shape().dimensions(i) << " vs. " << window_bounds[i] + << " with dim_numbers.elided_window_dims(0) = " + << dim_numbers.elided_window_dims(0); + return nullptr; + } + } + int64 source_dim = dim_numbers.gather_dims_to_operand_dims(0); std::vector output_dims; for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { @@ -336,7 +369,11 @@ std::vector ComputeReshapePassthroughDimPairs( // result_subarray_size does not include the elements in the current // `result_dim` dimension (we multiply in result_shape[result_dim] at the // end of loop body) so candidate_operand_dim can never be zero. - CHECK_NE(candidate_operand_dim, 0); + CHECK_NE(candidate_operand_dim, 0) + << "result_dim = " << result_dim + << ", result_subarray_size = " << result_subarray_size + << ", result_shape = [" << Join(result_shape, ",") << "]" + << ", operand_shape = [" << Join(operand_shape, ",") << "]"; if (candidate_operand_dim != -1 && result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) { @@ -357,7 +394,7 @@ std::vector ComputeReshapePassthroughDimPairs( }); VLOG(3) << "For a reshape from [" << Join(operand_shape, ",") << "] to [" << Join(result_shape, ",") << "] passthrough indices are [" - << Join(result_strings, ",") << "]"; + << Join(result_strings, ",") << "] (legend: `result`->`operand`)"; } DCHECK(c_is_sorted( @@ -398,6 +435,10 @@ int64 MapPassthroughOperandDimToResultDim( int64 FindSourcePositionForPassthroughResultDim(ArraySlice operand_shape, ArraySlice result_shape, int64 source_passthrough_dim) { + VLOG(3) << "FindSourcePositionForPassthroughResultDim([" + << Join(operand_shape, ",") << "], [" << Join(result_shape, ",") + << "], " << source_passthrough_dim << ")"; + int64 indexed_source_subarray_size = std::accumulate(operand_shape.begin() + source_passthrough_dim + 1, operand_shape.end(), 1, std::multiplies()); @@ -405,15 +446,191 @@ int64 FindSourcePositionForPassthroughResultDim(ArraySlice operand_shape, return FindSuffixWithProduct(result_shape, indexed_source_subarray_size); } +Shape StripDegenerateDimensions(const Shape& shape) { + DimensionVector new_dims; + c_copy_if(shape.dimensions(), std::back_inserter(new_dims), + [](int64 dim) { return dim != 1; }); + return ShapeUtil::MakeShape(shape.element_type(), new_dims); +} }; // namespace -StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( - const Shape& shape, Array* operand) { - auto* scalar_indexed = dynamic_cast(operand); - if (!scalar_indexed) { +StatusOr +IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims( + ScalarIndexedArray* operand) { + const Shape& shape = operand->shape(); + if (!ShapeUtil::HasDegenerateDimensions(shape)) { + return operand; + } + + // We only need to reshape out the degenerate dims from the indices and the + // source (except the source dim). + + const Shape& source_shape = operand->source()->shape(); + DimensionVector new_source_shape_dims; + for (int64 i = 0, e = source_shape.dimensions_size(); i < e; i++) { + if (i == operand->source_dim() || source_shape.dimensions(i) != 1) { + new_source_shape_dims.push_back(source_shape.dimensions(i)); + } + } + + Shape new_source_shape = + ShapeUtil::MakeShape(shape.element_type(), new_source_shape_dims); + Shape new_indices_shape = + StripDegenerateDimensions(operand->indices()->shape()); + + TF_ASSIGN_OR_RETURN( + Array* const new_source, + ComputeArrayForReshape(new_source_shape, operand->source())); + TF_ASSIGN_OR_RETURN( + Array* const new_indices, + ComputeArrayForReshape(new_indices_shape, operand->indices())); + + // Build the new output dims while keeping track of the degenerate dims that + // will no longer be present. + DimensionVector new_output_dims; + int64 degenerate_dims_seen = 0; + for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { + if (shape.dimensions(i) == 1) { + degenerate_dims_seen++; + } else if (ArrayContains(operand->output_dims(), i)) { + new_output_dims.push_back(i - degenerate_dims_seen); + } + } + + // Similarly, build the new source dim while keeping track of the degenerate + // dims that will no longer be present. + int64 degenerate_dims_before_source_dim = + std::count(source_shape.dimensions().begin(), + source_shape.dimensions().begin() + operand->source_dim(), 1); + int64 new_source_dim = + operand->source_dim() - degenerate_dims_before_source_dim; + + return ConstructScalarIndexedArray( + new_source, new_indices, new_source_dim, + InlinedVectorToVector(new_output_dims), + StripDegenerateDimensions(operand->shape())); +} + +StatusOr IndexedArrayAnalysis::ReshapeToAddDegenerateDims( + ScalarIndexedArray* operand, + tensorflow::gtl::ArraySlice degenerate_dims) { + if (degenerate_dims.empty()) { + return operand; + } + + CHECK(!ShapeUtil::HasDegenerateDimensions(operand->shape())); + + DimensionVector new_output_dims = [&]() { + // To make things easy we use a "scratch" buffer of bools where the i'th + // element is true iff the i'th component of the result index is an output + // index. + + gtl::InlinedVector output_dims_bitvector( + operand->shape().dimensions_size()); + for (int64 output_dim : operand->output_dims()) { + output_dims_bitvector[output_dim] = true; + } + + for (int64 degenerate_dim : degenerate_dims) { + InsertAt(&output_dims_bitvector, degenerate_dim, false); + } + + DimensionVector result; + result.reserve(operand->output_dims().size()); + for (int64 i = 0, e = output_dims_bitvector.size(); i < e; i++) { + if (output_dims_bitvector[i]) { + result.push_back(i); + } + } + + return result; + }(); + + DimensionVector new_result_shape_dims; + c_copy(operand->shape().dimensions(), + std::back_inserter(new_result_shape_dims)); + for (int64 degenerate_dim : degenerate_dims) { + InsertAt(&new_result_shape_dims, degenerate_dim, 1); + } + + DimensionVector new_source_shape_dims = new_result_shape_dims; + for (int64 output_dim : new_output_dims) { + EraseAt(&new_source_shape_dims, output_dim); + } + + int64 new_source_dim = [&]() { + for (int i = 0, e = new_source_shape_dims.size(); i < e; i++) { + int64 non_degenerate_dims_seen = 0; + if (non_degenerate_dims_seen == operand->source_dim()) { + return i; + } + if (new_source_shape_dims[new_source_dim] != 1) { + non_degenerate_dims_seen++; + } + } + LOG(FATAL) << "Did not find source dim in " << ToString(operand); + }(); + + int64 source_dim_size = + operand->source()->shape().dimensions(operand->source_dim()); + InsertAt(&new_source_shape_dims, /*index=*/new_source_dim, + /*value=*/source_dim_size); + + Shape new_source_shape = ShapeUtil::MakeShape(operand->shape().element_type(), + new_source_shape_dims); + Shape new_result_shape = ShapeUtil::MakeShape(operand->shape().element_type(), + new_result_shape_dims); + + TF_ASSIGN_OR_RETURN( + Array* const new_source, + ComputeArrayForReshape(new_source_shape, operand->source())); + return ConstructScalarIndexedArray( + new_source, operand->indices(), new_source_dim, + InlinedVectorToVector(new_output_dims), new_result_shape); +} + +StatusOr IndexedArrayAnalysis::FoldReshapeOfGather( + const Shape& shape, ScalarIndexedConstantArray* operand) { + VLOG(3) << "FoldReshapeOfGather(" << ToString(operand) << ")"; + + // To make things easier on ourselves, instead of directly trying to fold the + // reshape of `operand` to `shape`, we call + // `FoldReshapeOfGatherNoDegenerateDims` on shapes without degenerate dims and + // handle the degenerate dimensions here by inserting reshapes. + + TF_ASSIGN_OR_RETURN(ScalarIndexedArray* const operand_without_degenerate_dims, + ReshapeToRemoveDegenerateDims(operand)); + + Shape output_shape_without_degenerate_dims = StripDegenerateDimensions(shape); + TF_ASSIGN_OR_RETURN( + ScalarIndexedArray* const folded_reshape_without_degenerate_dims, + FoldReshapeOfGatherNoDegenerateDims( + output_shape_without_degenerate_dims, + operand_without_degenerate_dims->as())); + + if (folded_reshape_without_degenerate_dims == nullptr) { return nullptr; } + DimensionVector degenerate_result_dims; + for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { + if (shape.dimensions(i) == 1) { + degenerate_result_dims.push_back(i); + } + } + + return ReshapeToAddDegenerateDims(folded_reshape_without_degenerate_dims, + degenerate_result_dims); +} + +StatusOr +IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( + const Shape& shape, ScalarIndexedConstantArray* scalar_indexed) { + VLOG(3) << "FoldReshapeOfGatherNoDegenerateDims(" << ToString(scalar_indexed) + << ")"; + CHECK(!ShapeUtil::HasDegenerateDimensions(shape)); + CHECK(!ShapeUtil::HasDegenerateDimensions(scalar_indexed->shape())); + // Try to fold Reshape(ScalarIndexed(Const, Indices)) // => ScalarIndexed(Const', Indices) // @@ -464,7 +681,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( std::vector reshape_passthrough_dims = ComputeReshapePassthroughDimPairs( - /*operand_shape=*/AsInt64Slice(operand->shape().dimensions()), + /*operand_shape=*/AsInt64Slice(scalar_indexed->shape().dimensions()), /*result_shape=*/AsInt64Slice(shape.dimensions())); auto is_reshape_passthrough_operand_dim = [&](int64 operand_dim) { @@ -474,6 +691,8 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( if (!c_all_of(scalar_indexed->output_dims(), is_reshape_passthrough_operand_dim)) { + VLOG(3) << "Not all output dims are passthrough dims " + << ToString(scalar_indexed); return nullptr; } @@ -527,6 +746,11 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( // (a.k.a. isn't pass-through) than the [3,5,2] array. if (source_dim_for_new_scalar_indexed_node == -1) { + VLOG(3) << "Could not compute the source dim for the new scalar indexed " + "node: scalar_indexed_source_shape = [" + << Join(scalar_indexed_source_shape.dimensions(), ",") + << "] and new_scalar_indexed_source_shape = [" + << Join(new_scalar_indexed_source_shape, ",") << "]"; return nullptr; } @@ -534,6 +758,10 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( &new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node, scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim())); + CHECK_EQ(c_accumulate(new_scalar_indexed_source_shape, 1l, + std::multiplies()), + ShapeUtil::ElementsIn(scalar_indexed_source_shape)); + CHECK(IsReshapePassthroughOperandDim( ComputeReshapePassthroughDimPairs( /*operand_shape=*/AsInt64Slice( @@ -564,6 +792,31 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( output_dims_for_new_scalar_indexed_node, shape); } +StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( + const Shape& shape, Array* operand) { + if (ShapeUtil::Compatible(operand->shape(), shape)) { + return operand; + } + + if (auto* scalar_indexed = + dynamic_cast(operand)) { + TF_ASSIGN_OR_RETURN(Analysis::Array * reshape_folded_into_gather, + FoldReshapeOfGather(shape, scalar_indexed)); + if (reshape_folded_into_gather) { + return reshape_folded_into_gather; + } + } + + if (auto* constant_array = dynamic_cast(operand)) { + TF_ASSIGN_OR_RETURN(Literal* const new_literal, + TakeOwnership(constant_array->literal()->Reshape( + AsInt64Slice(shape.dimensions())))); + return Construct(new_literal); + } + + return Construct(operand, shape); +} + StatusOr IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, Array* lhs, diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index ce92fd2919c90fa8a2fb7b796ed6f0fdaf48fe62..8684430231c1929f82508e3675f1c275c42b6149 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -39,7 +39,13 @@ class IndexedArrayAnalysis { // Array instances are immutable once created. class Array { public: - enum Kind { kUnknown, kConstant, kScalarIndexedConstant, kScalarIndexed }; + enum Kind { + kUnknown, + kConstant, + kReshaped, + kScalarIndexedConstant, + kScalarIndexed + }; virtual Kind kind() const = 0; virtual const Shape& shape() const = 0; @@ -96,6 +102,27 @@ class IndexedArrayAnalysis { friend class IndexedArrayAnalysis; }; + // Represents an Array that is a reshape of another Array. + class ReshapedArray : public Array { + public: + Kind kind() const override { return kReshaped; } + + // The array to reshape. + Array* operand() const { return operand_; } + + // The output shape. + const Shape& shape() const override { return shape_; } + + private: + explicit ReshapedArray(Array* operand, Shape shape) + : operand_(operand), shape_(shape) {} + + Array* operand_; + const Shape shape_; + + friend class IndexedArrayAnalysis; + }; + // --------------------------------------------------------------------------- // Indexed Array Overview // --------------------------------------------------------------------------- @@ -266,6 +293,21 @@ class IndexedArrayAnalysis { ScalarIndexedArray* source, Array* indices, int64 source_dim, tensorflow::gtl::ArraySlice output_dims, Shape shape); + // Reshapes a scalar-indexed node to remove the degenerate dimensions in its + // output. The result is always a scalar-indexed node. + StatusOr ReshapeToRemoveDegenerateDims( + ScalarIndexedArray* operand); + + // Reshapes a scalar-indexed node such that the result has the degenerate + // dimensions `degenerate_dims`. The result is always a scalar-indexed node. + StatusOr ReshapeToAddDegenerateDims( + ScalarIndexedArray* operand, + tensorflow::gtl::ArraySlice degenerate_dims); + + StatusOr FoldReshapeOfGather( + const Shape& shape, ScalarIndexedConstantArray* operand); + StatusOr FoldReshapeOfGatherNoDegenerateDims( + const Shape& shape, ScalarIndexedConstantArray* scalar_indexed); StatusOr ComputeArrayForReshape(const Shape& shape, Array* operand); StatusOr ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 373556ebeba883f7dc2116bdf0ffc3274182f775..fc2befe05b18651502c42b9892e766145d85f2e8 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" @@ -34,6 +36,27 @@ class IndexedArrayAnalysisTest : public HloVerifiedTestBase { } private: + // Replaces seqences of whitespace with a single space. This makes the + // strings being matched against "whitespace insensitive" which lets us indent + // them for readability. + string CanonicalizeWhitespace(const string& text) { + string result; + + for (char c : text) { + if (!isspace(c)) { + result.push_back(c); + } else if (!result.empty() && result.back() != ' ') { + result.push_back(' '); + } + } + + while (!result.empty() && result.back() == ' ') { + result.pop_back(); + } + + return result; + } + void AssertArrayForRootExpressionIsImpl(const string& hlo_text, const string& root_expression, bool print_constants) { @@ -44,10 +67,10 @@ class IndexedArrayAnalysisTest : public HloVerifiedTestBase { IndexedArrayAnalysis::Array* const array_result, indexed_tensor_analysis.GetArrayFor( module().entry_computation()->root_instruction())); - string string_result = - indexed_tensor_analysis.ToString(array_result, print_constants); + string string_result = CanonicalizeWhitespace( + indexed_tensor_analysis.ToString(array_result, print_constants)); LOG(INFO) << string_result; - ASSERT_EQ(string_result, root_expression); + ASSERT_EQ(string_result, CanonicalizeWhitespace(root_expression)); } }; @@ -91,6 +114,82 @@ ENTRY main { hlo_text, "(scalar-indexed-const (constant s32[3,3]) %indices 0->[0])"); } +TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed0) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + indices = s32[5,2] parameter(0) + ROOT gather = s32[5] gather(operand, indices), + output_window_dims={}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%gather"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed1) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3,1] parameter(0) + indices = s32[5] parameter(1) + ROOT gather = s32[5,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,2}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3,1} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%gather"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed2) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3,1] parameter(0) + indices = s32[5] parameter(1) + ROOT gather = s32[5,2,3] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={2}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={2,3,1} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%gather"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed3) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[5] parameter(1) + ROOT gather = s32[5,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,2} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%gather"); +} + TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOne) { string hlo_text = R"( HloModule SimpleGather @@ -273,7 +372,157 @@ ENTRY main { "(scalar-indexed-const (constant s32[3,3,4]) %indices 0->[0,3])"); } -TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNegative0) { +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather3) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[2,6] constant(s32[2,6]{ + {1,2,3,4,5,6},{1,2,3,4,5,6}}) + indices = s32[1] parameter(0) + gather = s32[1,6] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,6} + ROOT reshape = s32[1,1,6] reshape(gather) +} +)"; + + const char* expected_root_expression = R"( +(scalar-indexed-const + (constant s32[2,1,1,6]) + (reshape %indices to s32[]) + 0->[]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather4) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 1, 2, 3 } }) + + i.0 = s64[1,3]{1,0} parameter(0) + g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), output_window_dims={2}, + elided_window_dims={0}, gather_dims_to_operand_dims={0}, + index_vector_dim=2, window_bounds={1,3} + + i.1 = s64[1] parameter(1) + g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), output_window_dims={0,2}, + elided_window_dims={1}, gather_dims_to_operand_dims={1}, + index_vector_dim=1, window_bounds={1,1,3} + + ROOT reshape = s32[1,3]{1,0} reshape(g.1) +} +)"; + + const char* expected_root_expression = R"( +(scalar-indexed-const + (constant s32[2,1,3]) + (reshape + (scalar-indexed %i.0 %i.1 1->[1]) + to s64[]) + 0->[]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather5) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[1,6] constant(s32[1,6]{{1,2,3,4,5,6}}) + indices = s32[1] parameter(0) + gather = s32[1,6] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,6} + ROOT reshape = s32[1,1,6] reshape(gather) +} +)"; + + const char* expected_root_expression = R"( +(scalar-indexed-const + (constant s32[1,1,1,6]) + (reshape %indices to s32[]) + 0->[]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather6) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[1,2,6] constant(s32[1,2,6]{{ + {1,2,3,4,5,6},{1,2,3,4,5,6}}}) + indices = s32[1] parameter(0) + gather = s32[1,1,6] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={1,1,6} + ROOT reshape = s32[1,1,1,6] reshape(gather) +} +)"; + + const char* expected_root_expression = R"( +(scalar-indexed-const + (constant s32[2,1,1,1,6] s32[2,1,1,1,6] { + { /*i0=0*/ { /*i1=0*/ { /*i2=0*/ {1, 2, 3, 4, 5, 6} } } }, + { /*i0=1*/ { /*i1=0*/ { /*i2=0*/ {1, 2, 3, 4, 5, 6} } } } }) + (reshape %indices to s32[]) + 0->[]) +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, + expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather7) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[2,6] constant(s32[2,6]{ + {1,2,3,4,5,6},{1,2,3,4,5,6}}) + indices = s32[1,5] parameter(0) + gather = s32[1,5,6] gather(operand, indices), + output_window_dims={2}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,6} + ROOT reshape = s32[1,1,5,6] reshape(gather) +} +)"; + + const char* expected_root_expression = R"( +(scalar-indexed-const + (constant s32[2,1,1,6] s32[2,1,1,6] { + { /*i0=0*/ { /*i1=0*/ {1, 2, 3, 4, 5, 6} } }, + { /*i0=1*/ { /*i1=0*/ {1, 2, 3, 4, 5, 6} } } }) + (reshape %indices to s32[5]) + 0->[2]) +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, + expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold0) { string hlo_text = R"( HloModule ReshapeOfGather @@ -290,10 +539,19 @@ ENTRY main { } )"; - AssertArrayForRootExpressionIs(hlo_text, "%reshape"); + const char* expected_root_expression = R"( +(reshape + (scalar-indexed-const + (constant s32[3,4]) + %indices + 0->[0,2]) + to s32[5,2,2,2,3]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); } -TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNegative1) { +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold1) { string hlo_text = R"( HloModule ReshapeOfGather @@ -313,7 +571,48 @@ ENTRY main { } )"; - AssertArrayForRootExpressionIs(hlo_text, "%reshape"); + const char* expected_root_expression = R"( +(reshape + (scalar-indexed-const + (constant s32[3,5,2]) + %indices + 1->[2]) + to s32[6,7]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold2) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,4,1] constant(s32[3,4,1]{ + {{1},{2},{3},{4}}, + {{1},{2},{3},{4}}, + {{1},{2},{3},{4}}}) + indices = s32[5,6] parameter(0) + gather = s32[5,4,6,1] gather(operand, indices), + output_window_dims={1,3}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,4,1} + ROOT reshape = s32[5,2,2,2,3,1] reshape(gather) +} +)"; + + const char* expected_root_expression = R"( +(reshape + (scalar-indexed-const + (constant s32[3,4,1]) + %indices + 0->[0,2]) + to s32[5,2,2,2,3,1]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); } TEST_F(IndexedArrayAnalysisTest, UnaryOpOfGather) { diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index abedb4063d3763516e66cff36633dbd90c8cafde..088cc2622695c7724dae2b6cde28fecd40547445 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -83,6 +83,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kNegate: case HloOpcode::kNot: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kOutfeed: case HloOpcode::kPad: case HloOpcode::kReal: @@ -96,7 +97,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kShiftRightLogical: case HloOpcode::kSlice: case HloOpcode::kSubtract: - case HloOpcode::kGenerateToken: + case HloOpcode::kAfterAll: case HloOpcode::kTranspose: case HloOpcode::kTuple: return false; @@ -237,6 +238,30 @@ InstructionFusion::ComputeGloballyUnfusable( if (EffectivelyAtMostUnary(producer)) { continue; } + + // If the total size of the inputs is less than or equal to the total size + // of the outputs for the producer then duplicating it won't increase the + // memory traffic. In that case, we do not forbid fusion of the operation + // here. + auto total_size = [](const Shape& shape) { + int64 size = 0; + ShapeUtil::ForEachSubshape( + shape, + [&size](const Shape& subshape, const ShapeIndex& shape_index) { + if (ShapeUtil::IsArray(subshape)) { + size += ShapeUtil::ElementsIn(subshape); + } + }); + return size; + }; + int64 operands_size = 0; + for (const HloInstruction* op : producer->operands()) { + operands_size += total_size(op->shape()); + } + if (operands_size <= total_size(producer->shape())) { + continue; + } + // Otherwise we will forbid fusing the op unless we can fuse it into // all of its consumers on all paths. // @@ -281,10 +306,8 @@ StatusOr InstructionFusion::Run(HloModule* module) { // map from HloInstruction* to the instruction's index in the vector. An // instruction is "removed" from the vector by setting it's element to // nullptr. - std::list post_order_list = + std::vector post_order = computation_->MakeInstructionPostOrder(); - std::vector post_order(post_order_list.begin(), - post_order_list.end()); tensorflow::gtl::FlatMap post_order_index; for (size_t i = 0; i < post_order.size(); ++i) { diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index c1666530687f2f8407a9dcb4e271c9d95552a689..9f8f4bda875cdff5e20fa8ca8eeecaa1140e2b9c 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -44,7 +44,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); pipeline.AddPass( - hlo_module->mutable_device_entry_computation_layout()); + hlo_module->mutable_entry_computation_layout()); return pipeline.Run(hlo_module).status(); } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 029e71058a7373b9310c6d9ffdb65f72ca28e5af..9816acf6507a0ed5391cf4f1c94ccd0f27f5227a 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -75,9 +75,9 @@ StatusOr InterpreterExecutable::ExecuteOnStream( // consumes. std::vector> arg_literals; for (int64 p = 0; p < computation->num_parameters(); ++p) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr arg_literal, - transfer_manager->TransferLiteralFromDevice(executor, *arguments[p])); + TF_ASSIGN_OR_RETURN(std::unique_ptr arg_literal, + transfer_manager->TransferLiteralFromDevice( + run_options->stream(), *arguments[p])); arg_literals.push_back(std::move(arg_literal)); } @@ -96,7 +96,7 @@ StatusOr InterpreterExecutable::ExecuteOnStream( result_literal->shape(), run_options->allocator(), executor->device_ordinal())); TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( - executor, *result_literal, result)); + run_options->stream(), *result_literal, result)); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); diff --git a/tensorflow/compiler/xla/service/interpreter/executor.cc b/tensorflow/compiler/xla/service/interpreter/executor.cc index 97e9fa2c8e8ecd918ffe3df2fd4e731f3b91e6db..4fb67bd0b72fc591c1ffa76ebb0513bf14ed3737 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.cc +++ b/tensorflow/compiler/xla/service/interpreter/executor.cc @@ -53,6 +53,7 @@ bool XlaInterpreterExecutor::Memcpy(Stream *stream, void *host_dst, AsExecutorStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() { port::Status ok = SynchronousMemcpy(host_dst, dev_src, size); }); + AsExecutorStream(stream)->BlockUntilDone(); return true; } @@ -61,6 +62,7 @@ bool XlaInterpreterExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, AsExecutorStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() { port::Status ok = SynchronousMemcpy(dev_dst, host_src, size); }); + AsExecutorStream(stream)->BlockUntilDone(); return true; } diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index eb469e77a08b976b91ed5e3cdea304a8148f73c5..36fdfa868dfbfaf9fbf353dd6623058d518fec04 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -175,41 +175,32 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, TF_RETURN_IF_ERROR( LayoutUtil::ValidateLayoutForShape(layout, buffer.shape())); - const BufferLayoutConstraint* curr_constraint = - GetBufferLayoutConstraint(buffer); - if (curr_constraint != nullptr) { - if (LayoutUtil::Equal(curr_constraint->layout(), layout)) { + auto iter = buffer_constraints_.find(&buffer); + if (iter != buffer_constraints_.end()) { + const BufferLayoutConstraint& curr_constraint = iter->second; + if (LayoutUtil::Equal(curr_constraint.layout(), layout)) { // New constraint matches existing constraint. Nothing to do. return Status::OK(); } - if (curr_constraint->mandatory()) { + if (curr_constraint.mandatory()) { return FailedPrecondition( "Buffer %s already has the layout constraint %s, cannot add " "incompatible constraint %s", buffer.ToString().c_str(), - LayoutUtil::HumanString(curr_constraint->layout()).c_str(), + LayoutUtil::HumanString(curr_constraint.layout()).c_str(), LayoutUtil::HumanString(layout).c_str()); } - } - - auto iter = buffer_constraints_.find(&buffer); - bool overwrite = iter != buffer_constraints_.end(); - if (!overwrite) { + iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs); + } else { + TF_RET_CHECK(unconstrained_buffer_ids_.erase(buffer.id()) == 1) + << buffer.ToString(); iter = buffer_constraints_ .insert(std::make_pair( &buffer, BufferLayoutConstraint(layout, buffer, mandatory, dfs))) .first; - } else { - iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs); } added_constraints_.push_back(&iter->second); - - // Remove buffer from the set of unconstrained buffers. - TF_RET_CHECK(unconstrained_buffer_ids_.count(buffer.id()) == - static_cast(!overwrite)); - unconstrained_buffer_ids_.erase(buffer.id()); - return Status::OK(); } @@ -716,7 +707,8 @@ Status CheckParameterLayout(HloInstruction* parameter, const ComputationLayout& computation_layout) { const ShapeLayout& parameter_layout = computation_layout.parameter_layout(parameter->parameter_number()); - if (!parameter_layout.MatchesLayoutInShape(parameter->shape())) { + if (parameter_layout.LayoutIsSet() && + !parameter_layout.MatchesLayoutInShape(parameter->shape())) { return InternalError( "parameter instruction %s does not match layout of computation " "shape: %s", @@ -936,6 +928,7 @@ LayoutAssignment::LayoutAssignment( ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints) : entry_computation_layout_(entry_computation_layout), + saved_entry_computation_layout_(*entry_computation_layout), channel_layout_constraints_(channel_constraints) { if (channel_layout_constraints_ != nullptr) { // Save a copy of the input ChannelLayoutConstraints so that we can reset it @@ -944,11 +937,6 @@ LayoutAssignment::LayoutAssignment( } VLOG(1) << "Entry computation layout given to layout assignment: " << entry_computation_layout_->ToString(); - // Layouts of all parameter instructions must be set. - for (const ShapeLayout& parameter_layout : - entry_computation_layout_->parameter_layouts()) { - CHECK(parameter_layout.LayoutIsSet()); - } } std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( @@ -1577,6 +1565,13 @@ Status LayoutAssignment::RunOnComputation( // Propagates layouts from mandatory and backend constraints. TF_RETURN_IF_ERROR(PropagateConstraints(&constraints)); + // Prior to applying default layouts, we take note of all HLO instructions + // which lack a layout constraint. + for (LogicalBuffer::Id buffer_id : constraints.unconstrained_buffer_ids()) { + unconstrained_layout_instructions_.insert( + points_to_analysis.GetBuffer(buffer_id).instruction()); + } + // While any unconstrained buffers remain, pick an arbitrary buffer, give it a // layout and propagate the change. while (!constraints.unconstrained_buffer_ids().empty()) { @@ -1721,13 +1716,14 @@ StatusOr LayoutAssignment::Run(HloModule* module) { // when seen from an outer instruction, which has across-computation // constraints to impose. // For example, the kWhile instruction needs to enforce the same layouts for - // the parameters and root of the bosy, as well as the condition parameters. + // the parameters and root of the body, as well as the condition parameters. // Similarly, the kConditional instruction needs to enforce the same layouts // for the root of the true and false computations. // So in the first pass, while allowing the layouts to flow to parameters and // root, we also fix up the eventually inconsistent ComputationLayout, which // will be then made mandatory by the second pass. for (int64 i = 0; i < 2; ++i) { + VLOG(5) << "Running " << (i == 0 ? "un" : "") << "constrained pass"; TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module)); TF_ASSIGN_OR_RETURN(auto points_to_analysis, TuplePointsToAnalysis::Run(module)); @@ -1765,10 +1761,12 @@ StatusOr LayoutAssignment::Run(HloModule* module) { Status LayoutAssignment::Init() { computation_layouts_.clear(); + *entry_computation_layout_ = saved_entry_computation_layout_; return Status::OK(); } Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) { + VLOG(5) << "Clearing previous side effects"; // Clear all the copies which have been added, and all the related // instructions (like GTE and tuples). int64 removed_copies = 0; @@ -1786,6 +1784,7 @@ Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) { } } added_copies_.clear(); + unconstrained_layout_instructions_.clear(); if (removed_copies > 0) { TupleSimplifier tuple_simplifier; HloDCE dce; diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index eb4cd5936b09145c7ba6351fdc9086d6d0f05bea..b75ecb311a07b996562460fc5d6fbd8e70ac056b 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -432,8 +432,13 @@ class LayoutAssignment : public HloPassInterface { Status PropagateComputationLayouts(HloComputation* computation, ComputationLayout* computation_layout); + // The pointer to the ComputationLayout passed as constructor parameter. ComputationLayout* entry_computation_layout_; + // A copy of entry_computation_layout_ used to reset it to the initial values + // during the multiple passes done by the layout assignment operation. + ComputationLayout saved_entry_computation_layout_; + protected: // Sets up the copy instruction according to the characteristic (sharding, // metadata, ...) of the reference instruction. The index argument is used @@ -501,6 +506,11 @@ class LayoutAssignment : public HloPassInterface { // case we have to undo operations due to the multiple passes over the // computations/instructions. ChannelLayoutConstraints channel_constraints_; + + // The set of HLO instructions which lacked any layout constraint, thus + // receiving propagated default layouts. + tensorflow::gtl::FlatSet + unconstrained_layout_instructions_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 62599b376a12808232c703479a0ccfd7a59aa9ad..67e2cf6c777b3ecc86cfa408145b9c3cd0c31df9 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -770,9 +770,13 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { false_builder.AddInstruction( HloInstruction::CreateParameter(0, tshape, "param")); // Using infeed as layout assignment does not mess up with it. - auto infeed = - false_builder.AddInstruction(HloInstruction::CreateInfeed(xshape, "")); - false_builder.AddInstruction(HloInstruction::CreateTuple({infeed})); + auto token = + false_builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto infeed = false_builder.AddInstruction( + HloInstruction::CreateInfeed(xshape, token, "")); + auto infeed_data = false_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(xshape, infeed, 0)); + false_builder.AddInstruction(HloInstruction::CreateTuple({infeed_data})); } HloComputation* false_computation = module->AddEmbeddedComputation(false_builder.Build()); diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 7323abeb2077154f82828bcda3e90eb45a67138a..ea10cef49a4a9aa048b3e0ea443f052645c4912a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -29,9 +29,9 @@ limitations under the License. namespace xla { namespace llvm_ir { -static void Delinearize(std::vector* multidim, - llvm::Value* linear, const Shape& shape, - llvm::IRBuilder<>* ir_builder) { +void IrArray::Index::Delinearize(std::vector* multidim, + llvm::Value* linear, const Shape& shape, + llvm::IRBuilder<>* ir_builder) const { int64 divisor = 1; const Layout& layout = shape.layout(); for (int64 i = 0; i < layout.minor_to_major_size(); ++i) { @@ -48,10 +48,11 @@ static void Delinearize(std::vector* multidim, // useful because cuda-memcheck can't help us much in XLA: Most of our // memory lives in one big allocation, so cuda-memcheck can't detect // out-of-bounds accesses. - auto* quot = ir_builder->CreateUDiv(linear, ir_builder->getInt64(divisor)); + auto* quot = + ir_builder->CreateUDiv(linear, GetConstantWithIndexType(divisor)); if (i < layout.minor_to_major_size() - 1) { (*multidim)[dimension] = ir_builder->CreateURem( - quot, ir_builder->getInt64(size_of_current_dimension)); + quot, GetConstantWithIndexType(size_of_current_dimension)); } else { (*multidim)[dimension] = quot; } @@ -65,6 +66,8 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape, linear_(linear), layout_(shape.layout()), dims_(shape.dimensions().begin(), shape.dimensions().end()) { + CHECK_NE(linear, nullptr); + index_type_ = linear->getType(); CHECK(LayoutUtil::HasLayout(shape)) << "Shape " << ShapeUtil::HumanStringWithLayout(shape) << " should have a layout."; @@ -77,6 +80,13 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, linear_(linear), layout_(shape.layout()), dims_(shape.dimensions().begin(), shape.dimensions().end()) { + if (size()) { + index_type_ = multidim_[0]->getType(); + } else { + CHECK_NE(linear_, nullptr); + index_type_ = linear_->getType(); + } + CHECK_NE(index_type_, nullptr); CHECK_EQ(shape.dimensions_size(), multidim.size()); CHECK(LayoutUtil::HasLayout(shape)) << "Shape " << ShapeUtil::HumanStringWithLayout(shape) @@ -88,6 +98,9 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, : multidim_(multidim.begin(), multidim.end()), layout_(shape.layout()), dims_(shape.dimensions().begin(), shape.dimensions().end()) { + CHECK_GT(multidim_.size(), 0); + index_type_ = multidim[0]->getType(); + CHECK_NE(index_type_, nullptr); CHECK_EQ(shape.dimensions_size(), multidim.size()); CHECK(LayoutUtil::HasLayout(shape)); } @@ -130,15 +143,15 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( CommonFactors(AsInt64Slice(input_shape.dimensions()), AsInt64Slice(output_shape.dimensions())); std::vector source_multidim_index( - ShapeUtil::Rank(input_shape), - llvm::UndefValue::get(builder->getInt64Ty())); + ShapeUtil::Rank(input_shape), llvm::UndefValue::get(index_type_)); // We compute the source indices in each common factor from only the target // indices in the same common factor. for (ssize_t k = common_factors.size() - 2; k >= 0; --k) { llvm::Value* logical_linear_index = Index(tensorflow::gtl::ArraySlice( multidim_, common_factors[k].second, - common_factors[k + 1].second - common_factors[k].second)) + common_factors[k + 1].second - common_factors[k].second), + index_type_) .Linearize( tensorflow::gtl::ArraySlice( AsInt64Slice(output_shape.dimensions()), @@ -150,9 +163,10 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( // linear index by each dimension size. for (int64 i = common_factors[k + 1].first - 1; i >= common_factors[k].first; --i) { - llvm::Value* divisor = builder->getInt64(input_shape.dimensions(i)); + llvm::Value* divisor = + GetConstantWithIndexType(input_shape.dimensions(i)); if (input_shape.dimensions(i) == 1) { - source_multidim_index[i] = builder->getInt64(0); + source_multidim_index[i] = GetConstantWithIndexType(0); } else if (i == common_factors[k].first) { source_multidim_index[i] = logical_linear_index; } else { @@ -168,14 +182,14 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) { return Index(source_multidim_index, linear(), input_shape); } - return Index(source_multidim_index); + return Index(source_multidim_index, index_type_); } IrArray::Index IrArray::Index::SourceIndexOfSlice( const Shape& shape, tensorflow::gtl::ArraySlice starts, tensorflow::gtl::ArraySlice strides, llvm::IRBuilder<>* builder) const { - Index source_index(multidim_.size()); + Index source_index(index_type_, multidim_.size()); for (int i = 0; i < multidim_.size(); ++i) { int64 stride = strides[i]; auto type = multidim_[i]->getType(); @@ -224,11 +238,12 @@ IrArray::Index IrArray::Index::SourceIndexOfBitcast( // the physical index of the element in the buffer. This is like Linearize, // but takes the layout into account. int64 scale = 1; - llvm::Value* linear_index = builder->getInt64(0); + llvm::Value* linear_index = GetConstantWithIndexType(0); for (auto dimension : LayoutUtil::MinorToMajor(shape)) { linear_index = builder->CreateAdd( linear_index, - builder->CreateMul(multidim_[dimension], builder->getInt64(scale), "", + builder->CreateMul(multidim_[dimension], + GetConstantWithIndexType(scale), "", /*HasNUW=*/true, /*HasNSW=*/true), "", /*HasNUW=*/true, /*HasNSW=*/true); scale *= shape.dimensions(dimension); @@ -252,7 +267,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( } if (linear_ == nullptr || !LayoutUtil::HasLayout(operand_shape) || !LayoutUtil::HasLayout(shape)) { - return Index(source_index); + return Index(source_index, index_type_); } // High-level idea: we can reuse the linear index if the broadcasted // dimensions are contiguous, and this part of the operation is a bitcast. @@ -274,7 +289,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( bool contiguous_broadcast_dimensions = max_broadcasted_dimension - min_broadcasted_dimension == rank - 1; if (!contiguous_broadcast_dimensions) { - return Index(source_index); + return Index(source_index, index_type_); } // Check if the mapped dimensions are a bitcast. std::vector operand_logical_to_physical = @@ -282,7 +297,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( for (int64 i = 0; i < rank; ++i) { if (operand_logical_to_physical[i] != logical_to_physical[dimension_mapping[i]] - min_broadcasted_dimension) { - return Index(source_index); + return Index(source_index, index_type_); } } llvm::Value* linear = linear_; @@ -291,7 +306,9 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i)); } if (divisor > 1) { - linear = builder->CreateUDiv(linear, builder->getInt64(divisor)); + linear = builder->CreateUDiv( + linear, + IrArray::Index(linear->getType()).GetConstantWithIndexType(divisor)); } if (min_broadcasted_dimension > 0) { int64 mod = 1; @@ -299,7 +316,9 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( ++i) { mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i)); } - linear = builder->CreateURem(linear, builder->getInt64(mod)); + linear = builder->CreateURem( + linear, + IrArray::Index(linear->getType()).GetConstantWithIndexType(mod)); } return Index(source_index, linear, operand_shape); } @@ -309,12 +328,13 @@ llvm::Value* IrArray::Index::Linearize( llvm::IRBuilder<>* builder) const { // Each dimension is multiplied by the product of the sizes of all // earlier dimensions and added to the accumulator logical_linear_index. - llvm::Value* logical_linear_index = builder->getInt64(0); + llvm::Value* logical_linear_index = GetConstantWithIndexType(0); int64 multiplier = 1; for (ssize_t i = size() - 1; i >= 0; --i) { llvm::Value* addend = - builder->CreateMul((*this)[i], builder->getInt64(multiplier), "", + builder->CreateMul((*this)[i], GetConstantWithIndexType(multiplier), "", /*HasNUW=*/true, /*HasNSW=*/true); + addend = builder->CreateZExtOrTrunc(addend, index_type_); logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "", /*HasNUW=*/true, /*HasNSW=*/true); multiplier *= dimensions[i]; @@ -349,7 +369,8 @@ llvm::Value* IrArray::EmitArrayElementAddress( // index[i] with 0. However, setting index[i] to 0 here still allows LLVM to // produce better code in some cases. auto dim = shape_->dimensions(i); - actual_index.push_back(dim == 1 ? ir_builder->getInt64(0) : index[i]); + actual_index.push_back( + dim == 1 ? llvm::ConstantInt::get(index[i]->getType(), 0) : index[i]); } // "base_ptr_" has the type of "*" @@ -357,7 +378,9 @@ llvm::Value* IrArray::EmitArrayElementAddress( // should be computed by // // getelementptr base_ptr_, 0, most major index, ..., most minor index - std::vector gep_indices(1, ir_builder->getInt64(0)); + CHECK_GT(index.size(), 0); + std::vector gep_indices( + 1, llvm::ConstantInt::get(index[0]->getType(), 0)); for (int64 i = 0; i < LayoutUtil::MinorToMajor(*shape_).size(); ++i) { int64 dimension = LayoutUtil::Major(shape_->layout(), i); gep_indices.push_back(actual_index[dimension]); @@ -410,7 +433,9 @@ IrArray IrArray::CastToShape(const Shape& new_shape, llvm::IRBuilder<>* ir_builder) { Index new_index = index; new_index[which_dimension] = ir_builder->CreateAdd( - index[which_dimension], ir_builder->getInt64(addend), "", /*HasNUW=*/true, + index[which_dimension], + llvm::ConstantInt::get(index[which_dimension]->getType(), addend), "", + /*HasNUW=*/true, /*HasNSW=*/true); return new_index; } diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 4c3195c29c859c9eef08e3f6531b059edbebfc47..4648c6d7ac089dbea7e660dd9889d557c8ad7318 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -53,18 +53,38 @@ class IrArray { // multidimensional index, which LLVM DCE can delete. class Index { public: - // Constructs an empty zero-dimensional index. - Index() {} - // Constructs an index of rank "size". Each dimension of the index is // initialized to "value". - explicit Index(size_t size, llvm::Value* value = nullptr) - : multidim_(size, value) {} + explicit Index(size_t size, llvm::Value* value) + : multidim_(size, value), index_type_(value->getType()) { + CHECK_NE(index_type_, nullptr); + } + + // Constructs an index of rank "size". Each dimension of the index is + // initialized to nullptr. + explicit Index(llvm::Type* index_ty, size_t size = 0) + : multidim_(size, nullptr), index_type_(index_ty) { + CHECK(index_ty->isIntegerTy()); + } // Constructs an index from multi-dimensional index "multidim". The linear // index is set to nullptr. - explicit Index(tensorflow::gtl::ArraySlice multidim) - : multidim_(multidim.begin(), multidim.end()) {} + explicit Index(tensorflow::gtl::ArraySlice multidim, + llvm::Type* index_ty = nullptr) + : multidim_(multidim.begin(), multidim.end()) { + if (size() == 0) { + index_type_ = index_ty; + } else { + index_type_ = (*this)[0]->getType(); + if (index_ty != nullptr) { + CHECK_EQ(index_type_, index_ty); + } + } + CHECK_NE(index_type_, nullptr); + CHECK(c_all_of(multidim, [&](llvm::Value* v) { + return index_type_ == v->getType(); + })); + } // Constructs an index from linear index "linear" and computes the // multi-dimensional index from "linear" and "shape". "ir_builder" is the IR @@ -154,6 +174,15 @@ class IrArray { llvm::Value* Linearize(tensorflow::gtl::ArraySlice dimensions, llvm::IRBuilder<>* builder) const; + llvm::Type* GetType() const { return index_type_; } + + llvm::Constant* GetConstantWithIndexType(int64 c) const { + // The LLVM function makes sure that the value can be represented by the + // specified type, see ConstantInt::ConstantInt(IntegerType *Ty, const + // APInt &V). + return llvm::ConstantInt::get(index_type_, c); + } + private: // Changing the multi-dimensional index invalidates the linear index. std::vector& multidim() { @@ -161,6 +190,9 @@ class IrArray { return multidim_; } + void Delinearize(std::vector* multidim, llvm::Value* linear, + const Shape& shape, llvm::IRBuilder<>* ir_builder) const; + std::vector multidim_; // These values are purely for efficiency; `multidim_` is enough to find the @@ -177,6 +209,8 @@ class IrArray { llvm::Value* linear_ = nullptr; Layout layout_; std::vector dims_; + + llvm::Type* index_type_; }; // Default constructor. Constructs an IrArray in a null status. diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index e17c649e5272a9e7c0d5126083ab76542abfdf48..6f7a9d94e3b9e59b2dfe12b9673335a904ae78b6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -125,8 +125,8 @@ class KernelSupportLibrary { llvm::Value* is_first_iteration)>& for_body_generator) { return For(name, /*start=*/start, /*end=*/end, - /*step=*/ir_builder_->getInt64(step), peel_first_iteration, - for_body_generator); + /*step=*/llvm::ConstantInt::get(start->getType(), step), + peel_first_iteration, for_body_generator); } void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, @@ -135,8 +135,8 @@ class KernelSupportLibrary { llvm::Value* is_first_iteration)>& for_body_generator) { ForReturnVoid(name, /*start=*/start, /*end=*/end, - /*step=*/ir_builder_->getInt64(step), peel_first_iteration, - for_body_generator); + /*step=*/llvm::ConstantInt::get(start->getType(), step), + peel_first_iteration, for_body_generator); } Status For( @@ -165,7 +165,7 @@ class KernelSupportLibrary { tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, int64 step, const std::function& for_body_generator) { - return For(name, start, end, ir_builder_->getInt64(step), + return For(name, start, end, llvm::ConstantInt::get(start->getType(), step), /*peel_first_iteration=*/false, [&](llvm::Value* indvar, llvm::Value*) -> Status { return for_body_generator(indvar); @@ -176,7 +176,8 @@ class KernelSupportLibrary { tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, int64 step, const std::function& for_body_generator) { - ForReturnVoid(name, start, end, ir_builder_->getInt64(step), + ForReturnVoid(name, start, end, + llvm::ConstantInt::get(start->getType(), step), for_body_generator); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 9f867014fb015845448c4fcf9c165750f8a61935..c9ae7d3afd5cdc21157732f6d0dfa824268e86bd 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -97,7 +97,7 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { ir_builder->SetInsertPoint(&func->getEntryBlock(), func->getEntryBlock().getFirstInsertionPt()); llvm::Value* indvar_address = - ir_builder->CreateAlloca(ir_builder->getInt64Ty(), nullptr, + ir_builder->CreateAlloca(start_index_->getType(), nullptr, AsStringRef(GetQualifiedName("invar_address"))); // Preheader basic block. @@ -185,7 +185,7 @@ std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* end_index, UnrollMode unroll_mode, bool prevent_vectorization) { - return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1), + return AddLoop(suffix, start_index, end_index, GetConstantWithIndexType(1), unroll_mode, prevent_vectorization); } @@ -223,8 +223,8 @@ std::unique_ptr ForLoopNest::AddLoop(int64 start_index, UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); - return AddLoop(suffix, ir_builder_->getInt64(start_index), - ir_builder_->getInt64(end_index), unroll_mode, + return AddLoop(suffix, GetConstantWithIndexType(start_index), + GetConstantWithIndexType(end_index), unroll_mode, prevent_vectorization); } @@ -234,9 +234,9 @@ std::unique_ptr ForLoopNest::AddLoop(int64 start_index, UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); - return AddLoop(suffix, ir_builder_->getInt64(start_index), - ir_builder_->getInt64(end_index), - ir_builder_->getInt64(stride), unroll_mode, + return AddLoop(suffix, GetConstantWithIndexType(start_index), + GetConstantWithIndexType(end_index), + GetConstantWithIndexType(stride), unroll_mode, prevent_vectorization); } @@ -250,7 +250,7 @@ IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions( const Shape& shape, tensorflow::gtl::ArraySlice dimensions, tensorflow::StringPiece suffix) { - llvm_ir::IrArray::Index index(shape.dimensions_size(), nullptr); + llvm_ir::IrArray::Index index(index_type_, shape.dimensions_size()); for (int64 dimension : dimensions) { std::unique_ptr loop = AddLoop( /*start_index=*/0, diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index 4e403cd994874c27453574283c5c573c876628db..0dd5b9d3b2656af68f76c2adfcb1f3a1385eeb91 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -177,15 +177,21 @@ class ForLoop { // A simple class for constructing nested for-loops. class ForLoopNest { public: - explicit ForLoopNest(llvm::IRBuilder<>* ir_builder) - : ForLoopNest(/*name=*/"", ir_builder) {} + explicit ForLoopNest(llvm::IRBuilder<>* ir_builder, + llvm::Type* index_ty = nullptr) + : ForLoopNest(/*name=*/"", ir_builder) { + SetIndexType(index_ty); + } - ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder) + ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder, + llvm::Type* index_ty = nullptr) : name_(std::string(name)), outer_loop_preheader_bb_(nullptr), outer_loop_exit_bb_(nullptr), inner_loop_body_bb_(nullptr), - ir_builder_(ir_builder) {} + ir_builder_(ir_builder) { + SetIndexType(index_ty); + } // Adds a loop to the nest. If no loop has been added yet then emit a loop at // the current insert point of the given builder. If one or more loops have @@ -252,6 +258,14 @@ class ForLoopNest { llvm::BasicBlock* GetInnerLoopBodyBasicBlock() { return inner_loop_body_bb_; } private: + void SetIndexType(llvm::Type* index_ty) { + index_type_ = index_ty == nullptr ? ir_builder_->getInt64Ty() : index_ty; + } + + llvm::Constant* GetConstantWithIndexType(int64 c) const { + return llvm::ConstantInt::get(index_type_, c); + } + // Human-friendly name of the loop nest. string name_; @@ -266,6 +280,8 @@ class ForLoopNest { llvm::IRBuilder<>* ir_builder_; + llvm::Type* index_type_; + TF_DISALLOW_COPY_AND_ASSIGN(ForLoopNest); }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index d18c9dee826eab5760d391bb8f7b5bd02ab659ae..97bacc34b59118e60100e4749638d469a1ef1378 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -249,167 +250,14 @@ StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, return shape; } -namespace { - -// Recursively construct a multidimensional LLVM constant which represents the -// given literal. The minor-to-major dimension ordering in the constant matches -// that of the literal. For example, given a [2 x 3 x 4] Literal (dimension 0 -// has size 4, dimension 1 has size 3, etc) of primitive type F32 with a -// minor_to_major value of [2, 1, 0] (column major), a LLVM constant of type -// [4 x [3 x [2 x float]] will be returned. -// -// multi_index is a multidimensional index into the array. dimension_index is an -// index into the minor_to_major field in the literal shape. This determines -// which dimension is iterated over in this level of the recursion. Dimensions -// are iterated from most major down to most minor (highest dimension_index -// value down to zero). -llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, - std::vector* multi_index, - llvm::Module* module) { - const Shape& shape = literal.shape(); - llvm::Type* ir_element_type = - llvm_ir::PrimitiveTypeToIrType(shape.element_type(), module); - if (dimension_index == -1) { - // Base case of the recursion. Index into the data field of the protobuf - // with the multi index. - llvm::Constant* value; - switch (shape.element_type()) { - case PRED: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case U8: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case S32: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case U32: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case S64: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case U64: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case F32: - value = llvm::ConstantFP::get(ir_element_type, - literal.Get(*multi_index)); - break; - case BF16: - value = llvm::ConstantInt::get( - ir_element_type, - tensorflow::bit_cast(literal.Get(*multi_index))); - break; - case F16: - value = llvm::ConstantFP::get( - ir_element_type, - static_cast(literal.Get(*multi_index))); - break; - case F64: - value = llvm::ConstantFP::get(ir_element_type, - literal.Get(*multi_index)); - break; - case C64: { - complex64 x = literal.Get(*multi_index); - value = llvm::ConstantStruct::get( - static_cast(ir_element_type), - llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module), - x.real()), - llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module), - x.imag())); - break; - } - default: - LOG(FATAL) << "unsupported type " << shape.element_type(); - } - return value; - } - - // The dimension index starts at the one less than the rank of the array and - // decrements with each recursive call. We want to iterate through the - // dimensions in major-to-minor order as we recurse so just index into - // minor_to_major to get the dimension number for this level of the recursion. - int64 dimension = LayoutUtil::Minor(shape.layout(), dimension_index); - - // Recursively call LiteralToConstant to construct subarrays for the - // more-minor dimensions. Gather the subarrays into a vector for bundling into - // a new (higher-dimensional) ConstantArray. - std::vector elements; - for (int64 i = 0; i < shape.dimensions(dimension); ++i) { - (*multi_index)[dimension] = i; - elements.push_back( - LiteralToConstant(literal, dimension_index - 1, multi_index, module)); - } - - llvm::Type* element_type; - if (elements.empty()) { - element_type = ir_element_type; - for (int i = 0; i < dimension_index; ++i) { - int64 index = LayoutUtil::Minor(shape.layout(), i); - element_type = - llvm::ArrayType::get(element_type, shape.dimensions(index)); - } - } else { - element_type = elements[0]->getType(); - } - llvm::ArrayType* aggregate_type = - llvm::ArrayType::get(element_type, shape.dimensions(dimension)); - return llvm::ConstantArray::get(aggregate_type, elements); -} - -template -llvm::Constant* GetConstantDataArray(const Literal& literal, - llvm::Module* module) { - const T* data = static_cast(literal.untyped_data()); - int64 num_elements = literal.size_bytes() / sizeof(T); - return llvm::ConstantDataArray::get(module->getContext(), - llvm::makeArrayRef(data, num_elements)); -} - -} // namespace - llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, llvm::Module* module) { - const Shape& shape = literal.shape(); - // TODO(b/29904935): We can get rid of this switch by exposing a - // ConstantDataArray factory method that takes a llvm::Type and a StringRef. - switch (shape.element_type()) { - case U64: - return GetConstantDataArray(literal, module); - case U32: - return GetConstantDataArray(literal, module); - case U8: - return GetConstantDataArray(literal, module); - case S64: - return GetConstantDataArray(literal, module); - case S32: - return GetConstantDataArray(literal, module); - case F64: - return GetConstantDataArray(literal, module); - case F32: - return GetConstantDataArray(literal, module); - case BF16: - case F16: - return GetConstantDataArray(literal, module); - case PRED: - return GetConstantDataArray(literal, module); - // TODO(b/29904935): Also use ConstantDataArray for complex numbers. - case C64: { - int64 dimensions = ShapeUtil::Rank(shape); - std::vector multi_index(dimensions, 0); - return LiteralToConstant(literal, /*dimension_index=*/dimensions - 1, - &multi_index, module); - } - default: - LOG(FATAL) << "unsupported type " << shape.element_type(); - } + const char* data = static_cast(literal.untyped_data()); + CHECK_EQ(module->getDataLayout().isLittleEndian(), + tensorflow::port::kLittleEndian); + return llvm::ConstantDataArray::getString( + module->getContext(), llvm::StringRef(data, literal.size_bytes()), + /*AddNull=*/false); } llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index dc2934a34c23f8229947210cacc9863d47c2ea55..e8b0605b9d75677b34f0973d88d269a5795b7629 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -90,11 +90,12 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, } std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name) { + tensorflow::StringPiece loop_name, llvm::Type* index_type) { + CHECK_NE(index_type, nullptr); if (ShapeUtil::IsScalar(shape_)) { // No loop needed, so set exit_bb_ to nullptr. exit_bb_ = nullptr; - return {IrArray::Index()}; + return {IrArray::Index(index_type)}; } // Create loop nest with one for-loop for each dimension of the target shape. @@ -102,7 +103,7 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( // class so emit loops in order from most-major dimension down to most-minor // dimension (of the target shape). ForLoopNest loop_nest(loop_name, ir_builder_); - IrArray::Index array_index(shape_.dimensions_size()); + IrArray::Index array_index(index_type, shape_.dimensions_size()); for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) { int64 dimension = LayoutUtil::Major(shape_.layout(), i); std::unique_ptr loop = loop_nest.AddLoop( @@ -125,9 +126,14 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( return {array_index}; } -Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) { +Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name, + llvm::Type* index_type) { + if (index_type == nullptr) { + index_type = ir_builder_->getInt64Ty(); + } + for (const IrArray::Index& array_index : - EmitIndexAndSetExitBasicBlock(loop_name)) { + EmitIndexAndSetExitBasicBlock(loop_name, index_type)) { TF_RETURN_IF_ERROR(body_emitter_(array_index)); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index b70d28ecd3033eb26629718e50ce48f39b162273..6be1c2fba2cbd78a02865901ef8c5b7e2b2a74e6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -65,13 +65,16 @@ class LoopEmitter { // specifies the element, will return multiple indices if the loop is // unrolled. std::vector EmitIndexAndSetExitBasicBlock() { - return EmitIndexAndSetExitBasicBlock(/*loop_name=*/""); + return EmitIndexAndSetExitBasicBlock(/*loop_name=*/"", + ir_builder_->getInt64Ty()); } + virtual std::vector EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name); + tensorflow::StringPiece loop_name, llvm::Type* index_type); // Emits a complete loop nest for every element in the given shape. - Status EmitLoop(tensorflow::StringPiece loop_name = ""); + Status EmitLoop(tensorflow::StringPiece loop_name = "", + llvm::Type* index_type = nullptr); protected: // An IR emitter that generates the loop body. diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.cc b/tensorflow/compiler/xla/service/llvm_ir/ops.cc index dacc54742c0897bbd92315f1e33a484aae56bb7f..3b298f4746d6177da52ba0227705d07fbeba5c19 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ops.cc @@ -45,7 +45,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( // Read start indices from start_indices_generator. const int64 rank = ShapeUtil::Rank(output_shape); - IrArray::Index start_index(rank); + IrArray::Index start_index(ir_builder->getInt64Ty(), rank); for (int64 i = 0; i < rank; ++i) { IrArray::Index dim_index({ir_builder->getInt64(i)}); TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(dim_index)); @@ -79,7 +79,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( // // output_index[dim] = start_index[dim] + update_index[dim] // - IrArray::Index output_index(rank); + IrArray::Index output_index(start_index.GetType(), rank); for (int64 i = 0; i < rank; ++i) { llvm::Value* start_index0 = ir_builder->CreateSExtOrBitCast( start_index[i], update_index[i]->getType()); diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 296d04d4362b12fdc39798a016ca9e8795e02586..53efc30c3653879709fceae3dcdd4f679740f622 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -154,7 +154,8 @@ StatusOr> LocalService::CompileExecutable( for (int i = 0; i < argument_layouts.size(); ++i) { const Shape& argument_shape = *argument_layouts[i]; - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape)); + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(argument_shape)); if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) { tensorflow::gtl::optional metadata = ParameterMetadata(computation, /*parameter_number=*/i); @@ -178,8 +179,8 @@ StatusOr> LocalService::CompileExecutable( } } if (build_options.result_layout() != nullptr) { - TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout( - *build_options.result_layout(), program_shape.result())); + TF_RETURN_IF_ERROR(ValidateResultShape(*build_options.result_layout(), + program_shape.result())); } ExecutionOptions execution_options = @@ -189,6 +190,9 @@ StatusOr> LocalService::CompileExecutable( std::unique_ptr module_config, CreateModuleConfig(program_shape, argument_layouts, &execution_options)); + VLOG(3) << "Computation Layout: " + << module_config->entry_computation_layout().ToString(); + TF_ASSIGN_OR_RETURN( se::StreamExecutor * executor, execute_backend_->stream_executor(build_options.device_ordinal())); diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index f9f9c7dcf788c24468cd474d9e7e20980135c1f0..4166ef5baf9c891968b584a0c498005e9ae87784 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -28,7 +28,7 @@ StatusOr MultiOutputFusion::Run(HloModule* module) { for (auto* computation : module->MakeNonfusionComputations()) { computation_ = computation; - reachability_ = computation_->ComputeReachability(); + RecomputeReachability(); candidates_.clear(); candidates_index_.clear(); all_fusion_candidates_.clear(); @@ -115,39 +115,18 @@ HloInstruction* MultiOutputFusion::Fuse(HloInstruction* instr1, HloInstruction* fused = instr2; // Make sure that if only one of the instructions is a fusion, or if only one // of the instructions is a multi-output fusion, it's what will be fused into. - // - // An invariant is that no bitcast nodes will show up in the middle of a - // fusion node. This invariant must hold in order for us to lower it. Given - // that, we require that during multi-output fusion, a fusion node ending with - // bitcast to preserve its structure as a nested fusion instead being - // merged and flattened. - if (fused->opcode() == HloOpcode::kFusion && - fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) { + if (fused->opcode() == HloOpcode::kFusion) { std::swap(remaining, fused); } if (fused->IsMultiOutputFusion()) { std::swap(remaining, fused); } - if (fused->opcode() == HloOpcode::kFusion && - fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) { + if (fused->opcode() == HloOpcode::kFusion) { remaining->MergeFusionInstructionIntoMultiOutput(fused); } else { - if (remaining->opcode() == HloOpcode::kFusion && - remaining->fused_expression_root()->opcode() == HloOpcode::kBitcast) { - auto parent_computation = remaining->parent(); - // Create a nested fusion node. - auto remaining_nested_fused = - parent_computation->AddInstruction(HloInstruction::CreateFusion( - remaining->shape(), HloInstruction::FusionKind::kLoop, - remaining)); - TF_CHECK_OK(parent_computation->ReplaceInstruction( - remaining, remaining_nested_fused)); - remaining = remaining_nested_fused; - } remaining->FuseInstructionIntoMultiOutput(fused); } - return remaining; } @@ -277,6 +256,10 @@ bool MultiOutputFusion::LegalToFuse(HloInstruction* instr1, return true; } +void MultiOutputFusion::RecomputeReachability() { + reachability_ = computation_->ComputeReachability(); +} + void MultiOutputFusion::UpdateReachability( HloInstruction* instr1, HloInstruction* instr2, tensorflow::gtl::ArraySlice instrs_to_update, @@ -345,14 +328,11 @@ bool MultiOutputFusion::Perform() { --fuel_; } } - if (DoProducerConsumerMultiOutputFusion(computation_)) { + if (DoProducerConsumerMultiOutputFusion()) { changed = true; } return changed; } -bool MultiOutputFusion::DoProducerConsumerMultiOutputFusion( - HloComputation* /*computation*/) { - return false; -} +bool MultiOutputFusion::DoProducerConsumerMultiOutputFusion() { return false; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index d9c36fa284347d1efa16d8d3e45da807c3b8bf8b..0019cd725417d81900974b462c3b05075ce3e893 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -78,6 +78,19 @@ class MultiOutputFusion : public HloPassInterface { // Test if it's legal to fuse instr1 and instr2 into one fusion instruction. virtual bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2); + // Fuse HloInstrctuion instr1 and instr2 and return the fused instruction. + // The other instruction is removed from its parent computation. + virtual HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2); + + // Recompute reachability for the current computation. + void RecomputeReachability(); + + // Returns the reachability map for the current computation. + HloReachabilityMap* reachability() const { return reachability_.get(); } + + // Returns the computation for the pass. + HloComputation* computation() const { return computation_; } + // Update the reachability map after fusing instr1 and instr2. void UpdateReachability( HloInstruction* instr1, HloInstruction* instr2, @@ -89,13 +102,9 @@ class MultiOutputFusion : public HloPassInterface { // // TODO(b/80420762): Perform producer-consumer multi-output fusion in // InstructionFusion instead. - virtual bool DoProducerConsumerMultiOutputFusion(HloComputation* computation); + virtual bool DoProducerConsumerMultiOutputFusion(); private: - // Fuse HloInstrctuion instr1 and instr2 and return the fused instruction. - // The other instruction is removed from its parent computation. - HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2); - // Update the internal data structures after instr1 and instr2 are fused into // one fusion instruction. void Update(HloInstruction* instr1, HloInstruction* instr2); diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index 3a6a7c25f4b727c7112dbcbcb4f3d892679a0011..f6e7578a89551ec2f23d4d8c8b488c3c10e0bf1c 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -67,22 +67,17 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { has_numeric_suffix = true; // Remove numeric suffix from root. root = root.substr(0, separator_index); - // Update count to at least the numeric suffix value to avoid future - // colisions with this name. - generated_names_[root] = std::max(generated_names_[root], numeric_suffix); } } - int64* count = &(generated_names_[root]); - if (*count == 0) { - *count = 1; + + SequentialIdGenerator& id_generator = generated_names_[root]; + numeric_suffix = id_generator.RegisterId(numeric_suffix); + if (numeric_suffix == 0) { return has_numeric_suffix ? tensorflow::strings::StrCat(root, separator_, 0) : root; - } else { - tensorflow::strings::StrAppend(&root, separator_, *count); - // Increment lookup under old 'root' name. - (*count)++; - return root; } + tensorflow::strings::StrAppend(&root, separator_, numeric_suffix); + return root; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index 4139c2700b25e8600182a034a8ac6f4f041c12e6..4423d6106920eaeab830bd9dc08529ff409a5161 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -17,10 +17,11 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_NAME_UNIQUER_H_ #include -#include #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -44,13 +45,40 @@ class NameUniquer { static string GetSanitizedName(const string& name); private: + // Used to track and generate new identifiers for the same instruction name + // root. + class SequentialIdGenerator { + public: + SequentialIdGenerator() = default; + + // Tries to register id as used identifier. If id is not already used, the + // id itself will be returned. Otherwise a new one will be generated, and + // returned. + int64 RegisterId(int64 id) { + if (used_.insert(id).second) { + return id; + } + while (!used_.insert(next_).second) { + ++next_; + } + return next_++; + } + + private: + // The next identifier to be tried. + int64 next_ = 0; + + // Set of all the identifiers which has been used. + tensorflow::gtl::FlatSet used_; + }; + // The string to use to separate the prefix of the name from the uniquing // integer value. string separator_; - // Map from name prefix to the number of names generated using that prefix - // so far. - std::unordered_map generated_names_; + // Map from name prefix to the generator data structure which tracks used + // identifiers and generates new ones. + tensorflow::gtl::FlatMap generated_names_; TF_DISALLOW_COPY_AND_ASSIGN(NameUniquer); }; diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc index 2ec255558c4ed3695ec6c824458cbedac44dc297..3e2592c6ac626143f1421e545a31d9be91e376bc 100644 --- a/tensorflow/compiler/xla/service/name_uniquer_test.cc +++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc @@ -54,12 +54,13 @@ TEST_F(NameUniquerTest, NumericSuffixes) { EXPECT_EQ("foo", uniquer.GetUniqueName("foo")); EXPECT_EQ("foo.54", uniquer.GetUniqueName("foo.54")); - EXPECT_EQ("foo.55", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo.1", uniquer.GetUniqueName("foo")); EXPECT_EQ("foo.55.1", uniquer.GetUniqueName("foo.55.1")); - EXPECT_EQ("foo.55.2", uniquer.GetUniqueName("foo.55.1")); - EXPECT_EQ("bar.0", uniquer.GetUniqueName("bar.-1000")); - EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.-2000")); - EXPECT_EQ("bar.2", uniquer.GetUniqueName("bar.1")); + EXPECT_EQ("foo.55.0", uniquer.GetUniqueName("foo.55.1")); + EXPECT_EQ("bar.1000", uniquer.GetUniqueName("bar.1000")); + EXPECT_EQ("bar.2000", uniquer.GetUniqueName("bar.2000")); + EXPECT_EQ("bar.-2000", uniquer.GetUniqueName("bar.-2000")); + EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.1")); } TEST_F(NameUniquerTest, PrefixHasSuffix) { @@ -77,12 +78,12 @@ TEST_F(NameUniquerTest, Sanitize) { EXPECT_EQ("foo.54", uniquer.GetUniqueName("foo.54")); EXPECT_EQ("foo_54", uniquer.GetUniqueName("foo_54")); EXPECT_EQ("foo_54.1", uniquer.GetUniqueName("foo_54.1")); - EXPECT_EQ("foo_55", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo_2", uniquer.GetUniqueName("foo")); // Invalid characters will be replaced with '_'. - EXPECT_EQ("bar_0", uniquer.GetUniqueName("bar<-1000")); - EXPECT_EQ("bar_1", uniquer.GetUniqueName("bar<-2000")); - EXPECT_EQ("bar_2", uniquer.GetUniqueName("bar_1")); + EXPECT_EQ("bar_1000", uniquer.GetUniqueName("bar<1000")); + EXPECT_EQ("bar_2000", uniquer.GetUniqueName("bar<2000")); + EXPECT_EQ("bar_1", uniquer.GetUniqueName("bar_1")); // Separator is only recognized in the middle of the prefix. EXPECT_EQ("_10", uniquer.GetUniqueName( @@ -93,5 +94,15 @@ TEST_F(NameUniquerTest, Sanitize) { EXPECT_EQ("foobar__1", uniquer.GetUniqueName("foobar_")); } +TEST_F(NameUniquerTest, KeepNamesInRandomOrder) { + NameUniquer uniquer("."); + + EXPECT_EQ("foo.11", uniquer.GetUniqueName("foo.11")); + EXPECT_EQ("foo.10", uniquer.GetUniqueName("foo.10")); + EXPECT_EQ("foo.1", uniquer.GetUniqueName("foo.1")); + EXPECT_EQ("foo.12", uniquer.GetUniqueName("foo.12")); + EXPECT_EQ("foo.3", uniquer.GetUniqueName("foo.3")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 961158e677baa46465af3f1f9a62929d14547c30..da3b622bfae8ac5132f9f95070ee41674e79b5b8 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -64,25 +64,25 @@ namespace { // Records the arguments used to invoke a computation in an HloSnapshot proto. Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, - se::StreamExecutor* executor, TransferManager* transfer_manager, + se::Stream* stream, TransferManager* transfer_manager, HloSnapshot* module) { module->clear_arguments(); for (const ShapedBuffer* argument : arguments) { TF_ASSIGN_OR_RETURN( std::unique_ptr literal, - transfer_manager->TransferLiteralFromDevice(executor, *argument)); + transfer_manager->TransferLiteralFromDevice(stream, *argument)); *module->add_arguments() = literal->ToProto(); } return Status::OK(); } // Records the result of a computation in a HloSnapshot proto. -Status RecordResult(const ShapedBuffer& result, se::StreamExecutor* executor, +Status RecordResult(const ShapedBuffer& result, se::Stream* stream, TransferManager* transfer_manager, HloSnapshot* module) { module->clear_result(); TF_ASSIGN_OR_RETURN( std::unique_ptr literal, - transfer_manager->TransferLiteralFromDevice(executor, result)); + transfer_manager->TransferLiteralFromDevice(stream, result)); *module->mutable_result() = literal->ToProto(); return Status::OK(); } @@ -191,21 +191,17 @@ Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, return Status::OK(); } -Status Service::ValidateResultShapeWithLayout(const Shape& shape_with_layout, - const Shape& result_shape) const { - if (!ShapeUtil::Compatible(shape_with_layout, result_shape)) { +Status Service::ValidateResultShape(const Shape& client_shape, + const Shape& result_shape) const { + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(client_shape)); + if (!ShapeUtil::Compatible(client_shape, result_shape)) { return InvalidArgument( "Shape used to set computation result layout %s is not compatible " "with result shape %s", - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), + ShapeUtil::HumanStringWithLayout(client_shape).c_str(), ShapeUtil::HumanString(result_shape).c_str()); } - if (!LayoutUtil::HasLayout(shape_with_layout)) { - return InvalidArgument( - "Shape used to set computation result layout %s does not have layout", - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); - } - return ShapeUtil::ValidateShape(shape_with_layout); + return Status::OK(); } StatusOr>> @@ -248,10 +244,8 @@ StatusOr> Service::CreateModuleConfig( tensorflow::gtl::ArraySlice argument_shapes, const ExecutionOptions* execution_options) { auto config = MakeUnique(program_shape); - ComputationLayout* host_computation_layout = - config->mutable_host_entry_computation_layout(); - ComputationLayout* device_computation_layout = - config->mutable_device_entry_computation_layout(); + ComputationLayout* computation_layout = + config->mutable_entry_computation_layout(); if (program_shape.parameters_size() != argument_shapes.size()) { return InvalidArgument("computation takes %d parameters, but %zu given", program_shape.parameters_size(), @@ -268,32 +262,22 @@ StatusOr> Service::CreateModuleConfig( i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), ShapeUtil::HumanString(*argument_shapes[i]).c_str()); } - TF_RETURN_IF_ERROR(host_computation_layout->mutable_parameter_layout(i) - ->CopyLayoutFromShape(*argument_shapes[i])); - TF_RETURN_IF_ERROR(device_computation_layout->mutable_parameter_layout(i) - ->CopyLayoutFromShape(*argument_shapes[i])); + TF_RETURN_IF_ERROR( + computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( + *argument_shapes[i])); } if (execution_options != nullptr && execution_options->has_shape_with_output_layout()) { const auto& shape_with_output_layout = execution_options->shape_with_output_layout(); - TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(shape_with_output_layout, - program_shape.result())); TF_RETURN_IF_ERROR( - host_computation_layout->mutable_result_layout()->CopyLayoutFromShape( - shape_with_output_layout)); + ValidateResultShape(shape_with_output_layout, program_shape.result())); TF_RETURN_IF_ERROR( - device_computation_layout->mutable_result_layout()->CopyLayoutFromShape( + computation_layout->mutable_result_layout()->CopyLayoutFromShape( shape_with_output_layout)); } else { // If the result layout is not set, then choose the default. - // TODO(b/29118294): Allow the compiler to choose a better layout in this - // case. - // TODO(b/78356948): We are forcing the default layout here. We should fix - // clients which expect a default layout, to be explicit about it, by - // passing the proper ExecutionOptions with shape_with_output_layout set. - host_computation_layout->mutable_result_layout()->SetToDefaultLayout(); - device_computation_layout->mutable_result_layout()->SetToDefaultLayout(); + computation_layout->mutable_result_layout()->SetToDefaultLayout(); } config->set_replica_count(options_.number_of_replicas()); @@ -381,22 +365,6 @@ StatusOr>> Service::BuildExecutables( return std::move(executables); } -Status Service::ValidateEntryComputationLayout(HloModule* module) { - const ComputationLayout& on_device = - module->device_entry_computation_layout(); - for (int64 i = 0; i < on_device.parameter_count(); ++i) { - TF_RET_CHECK(ShapeUtil::Equal( - on_device.parameter_shape(i), - execute_backend_->transfer_manager()->HostShapeToDeviceShape( - module->host_entry_computation_layout().parameter_shape(i)))); - } - TF_RET_CHECK(ShapeUtil::Equal( - module->device_entry_computation_layout().result_shape(), - execute_backend_->transfer_manager()->HostShapeToDeviceShape( - module->host_entry_computation_layout().result_shape()))); - return Status::OK(); -} - StatusOr> Service::ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice executables, @@ -498,7 +466,7 @@ Service::ExecuteParallelAndRegisterResult( HloExecutionProfile hlo_profile(&executable->hlo_profile_printer_data(), &executable->hlo_profile_index_map()); TF_RETURN_IF_ERROR( - executable->PopulateExecutionProfile(&hlo_profile, stream->parent())); + executable->PopulateExecutionProfile(&hlo_profile, stream)); XLA_LOG_LINES( tensorflow::INFO, hlo_profile.ToString(streams[0]->parent()->GetDeviceDescription())); @@ -692,7 +660,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, request.execution_options())); VLOG(3) << "ExecuteGraphParallel created HloModuleConfig computation layout: " - << module_config->host_entry_computation_layout().ToString(); + << module_config->entry_computation_layout().ToString(); // Adds to the vectors to build and execute the computations after the loop. all_arguments.push_back(replicated_arguments); @@ -723,8 +691,10 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, for (int i = 0; i < executable_ptrs.size(); i++) { if (executable_ptrs[i]->dumping_snapshot()) { - TF_RETURN_IF_ERROR(RecordArguments(all_arguments[i].front(), - all_executors[i][0], + TF_ASSIGN_OR_RETURN(auto stream, + execute_backend_->BorrowStream( + all_executors[i][0]->device_ordinal())); + TF_RETURN_IF_ERROR(RecordArguments(all_arguments[i].front(), stream.get(), execute_backend_->transfer_manager(), executable_ptrs[i]->hlo_snapshot())); } @@ -749,7 +719,9 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, if (executable_ptrs[i]->dumping_snapshot()) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer, allocation_tracker_.ResolveForReplica(outputs[i], 0)); - TF_RETURN_IF_ERROR(RecordResult(*result_buffer, all_executors[i][0], + TF_ASSIGN_OR_RETURN(auto stream, + execute_backend_->BorrowStream(all_executors[i][0])); + TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(), execute_backend_->transfer_manager(), executable_ptrs[i]->hlo_snapshot())); // Dump out the ith snapshot. @@ -849,8 +821,6 @@ StatusOr> Service::BuildExecutable( TF_ASSIGN_OR_RETURN( module, backend->compiler()->RunHloPasses(std::move(module), executor, device_allocator)); - // Check that on-host and on-device shapes are consistent. - TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get())); TF_ASSIGN_OR_RETURN(std::unique_ptr executable, backend->compiler()->RunBackend( @@ -897,12 +867,14 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, execute_backend_->default_stream_executor(), /*device_allocator=*/nullptr)); + TF_ASSIGN_OR_RETURN(auto stream, + execute_backend_->BorrowStream( + execute_backend_->default_stream_executor())); if (executable->dumping_snapshot()) { executable->hlo_snapshot()->set_execution_platform( execute_backend_->platform()->Name()); TF_RETURN_IF_ERROR(RecordArguments( - replicated_arguments.front(), - execute_backend_->default_stream_executor(), + replicated_arguments.front(), stream.get(), execute_backend_->transfer_manager(), executable->hlo_snapshot())); } @@ -916,9 +888,9 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, TF_ASSIGN_OR_RETURN( const ShapedBuffer* result_buffer, allocation_tracker_.ResolveForReplica(result->output(), 0)); - TF_RETURN_IF_ERROR(RecordResult( - *result_buffer, execute_backend_->default_stream_executor(), - execute_backend_->transfer_manager(), executable->hlo_snapshot())); + TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(), + execute_backend_->transfer_manager(), + executable->hlo_snapshot())); TF_RETURN_IF_ERROR(executable->DumpHloSnapshot()); } @@ -956,14 +928,13 @@ Status Service::TransferToClient(const TransferToClientRequest* arg, return_shape = &shaped_buffer->on_host_shape(); } - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - execute_backend_->stream_executor(shaped_buffer->device_ordinal())); + TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream( + shaped_buffer->device_ordinal())); TF_ASSIGN_OR_RETURN( std::unique_ptr result_literal, execute_backend_->transfer_manager()->TransferLiteralFromDevice( - executor, *shaped_buffer)); + stream.get(), *shaped_buffer)); if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal->shape())) { @@ -1013,9 +984,10 @@ Status Service::TransferToServer(const TransferToServerRequest* arg, execute_backend_->transfer_manager()->AllocateScopedShapedBuffer( shape, execute_backend_->memory_allocator(), executor->device_ordinal())); + TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor)); TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralToDevice( - executor, *literal, shaped_buffer)); + stream.get(), *literal, shaped_buffer)); replicated_buffers.emplace_back(std::move(shaped_buffer)); } TF_ASSIGN_OR_RETURN(*result->mutable_data(), diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 8748a4c1447eca691abc0f7ca48feda48ceb86e1..47d196fb2aaee897ce1fd3745129af10bf5b2d2d 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -193,9 +193,6 @@ class Service : public ServiceInterface { const ExecutionOptions& execution_options, tensorflow::gtl::ArraySlice arguments); - // Assert that host- and device-shapes are in a consistent state. - Status ValidateEntryComputationLayout(HloModule* module); - protected: friend class LocalExecutable; @@ -266,11 +263,11 @@ class Service : public ServiceInterface { // will be the result of this computation. Status ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result); - // Convenience function which checks whether the given shape_with_layout + // Convenience function which checks whether the given client_shape // (presumably passed by the client to set the result layout) is valid for the // given computation result shape. - Status ValidateResultShapeWithLayout(const Shape& shape_with_layout, - const Shape& result_shape) const; + Status ValidateResultShape(const Shape& client_shape, + const Shape& result_shape) const; // Returns the stream executors assigned to the replicas represented by the // given device handle. Each device_handle is a virtual replicated device that diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index e25f5e67c719430c0e7a8e0bb059efdc01ea75f9..d05e995a95625f75ef4b694a1fbc8368ebed51c8 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -239,7 +239,6 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, case HloOpcode::kNegate: case HloOpcode::kRoundNearestAfz: case HloOpcode::kSign: - case HloOpcode::kSort: return shape; case HloOpcode::kNot: @@ -329,7 +328,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::MakeShape(element_type, new_dimensions); } -/* static */ StatusOr ShapeInference::InferGenerateTokenShape( +/* static */ StatusOr ShapeInference::InferAfterAllShape( tensorflow::gtl::ArraySlice arg_shapes) { for (const Shape* arg_shape : arg_shapes) { if (arg_shape->element_type() != TOKEN) { @@ -885,6 +884,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } case HloOpcode::kAnd: case HloOpcode::kOr: + case HloOpcode::kXor: if (lhs.element_type() != PRED && !primitive_util::IsIntegralType(lhs.element_type())) { return InvalidArgument( @@ -939,6 +939,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, HloOpcode opcode, tensorflow::gtl::ArraySlice operands) { std::vector operand_shapes; + operand_shapes.reserve(operands.size()); for (const HloInstruction* operand : operands) { operand_shapes.push_back(&operand->shape()); } @@ -954,11 +955,21 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, switch (opcode) { case HloOpcode::kTuple: { Shape result = ShapeUtil::MakeTupleShape({}); + result.mutable_tuple_shapes()->Reserve(operand_shapes.size()); for (const Shape* shape : operand_shapes) { ShapeUtil::AppendShapeToTuple(*shape, &result); } return result; } + case HloOpcode::kSort: { + if (operand_shapes.size() == 1) { + return *operand_shapes[0]; + } else if (operand_shapes.size() == 2) { + return ShapeUtil::MakeTupleShape( + {*operand_shapes[0], *operand_shapes[1]}); + } + return InvalidArgument("Unexpected number of operands for sort"); + } default: return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode).c_str()); diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index eef6e62fc8d933452ebc3f9a5b8bc49828455be5..ad34a2aa184e786a9825193d23f106f8a950758a 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -216,11 +216,11 @@ class ShapeInference { static StatusOr InferConcatOpShape( tensorflow::gtl::ArraySlice arg_shapes, int64 dimension); - // Infers the shape produced by a kGenerateToken operation. Trivially this - // shape is always a TOKEN shape. However, ShapeInference serves two purposes: - // inferring shapes and checking operand shapes. This method verifies that the - // operand shapes are all TOKENs. - static StatusOr InferGenerateTokenShape( + // Infers the shape produced by a kAfterAll. Trivially this shape is always a + // TOKEN shape. However, ShapeInference serves two purposes: inferring shapes + // and checking operand shapes. This method verifies that the operand shapes + // are all TOKENs. + static StatusOr InferAfterAllShape( tensorflow::gtl::ArraySlice arg_shapes); // Helper that validates the given operand shape can be converted to the diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index c4d01562c4e32225ebb984d8fcd93ec3fa86e403..4c5038a009ba5da4172129980014913f3f4418f4 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -22,8 +22,12 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/notification.h" + +using ::tensorflow::strings::StrCat; namespace xla { /* static */ tensorflow::mutex @@ -36,8 +40,73 @@ TransferManager::GetPlatformTransferManagers() { return r; } +StatusOr> TransferManager::TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer) { + StatusOr> ret; + se::Stream* substream = stream->GetOrCreateSubStream(); + auto cleanup = tensorflow::gtl::MakeCleanup( + [&]() { stream->ReturnSubStream(substream); }); + + tensorflow::Notification n; + TransferLiteralFromDevice(substream, device_buffer, + [&](StatusOr> arg) { + ret = std::move(arg); + n.Notify(); + }); + n.WaitForNotification(); + return ret; +} + +Status TransferManager::TransferLiteralToDevice( + se::Stream* stream, const LiteralSlice& literal, + const ShapedBuffer& device_buffer) { + // Implement the synchronous version by waiting on the asynchronous version. + // Use a substream so that if we are called from a HostCallback we don't + // deadlock. + se::Stream* substream = stream->GetOrCreateSubStream(); + auto cleanup = tensorflow::gtl::MakeCleanup( + [&]() { stream->ReturnSubStream(substream); }); + TF_RETURN_IF_ERROR( + TransferLiteralToDeviceAsync(substream, literal, device_buffer)); + return substream->BlockHostUntilDone(); +} + +StatusOr> TransferManager::TransferArrayFromDevice( + se::Stream* stream, const Shape& shape, + const se::DeviceMemoryBase& source) { + // Implement the synchronous version by waiting on the asynchronous version. + // Use a substream so that if we are called from a HostCallback we don't + // deadlock. + StatusOr> ret; + se::Stream* substream = stream->GetOrCreateSubStream(); + auto cleanup = tensorflow::gtl::MakeCleanup( + [&]() { stream->ReturnSubStream(substream); }); + + tensorflow::Notification n; + TransferArrayFromDevice(substream, shape, source, + [&](StatusOr> arg) { + ret = std::move(arg); + n.Notify(); + }); + n.WaitForNotification(); + return ret; +} + Status TransferManager::TransferArrayToDevice( - se::StreamExecutor* executor, const LiteralSlice& literal, + se::Stream* stream, const LiteralSlice& literal, + const se::DeviceMemoryBase& dest) { + // Implement the synchronous version by waiting on the asynchronous version. + // Use a substream so that if we are called from a HostCallback we don't + // deadlock. + se::Stream* substream = stream->GetOrCreateSubStream(); + auto cleanup = tensorflow::gtl::MakeCleanup( + [&]() { stream->ReturnSubStream(substream); }); + TF_RETURN_IF_ERROR(TransferArrayToDeviceAsync(substream, literal, dest)); + return substream->BlockHostUntilDone(); +} + +Status TransferManager::TransferArrayToDeviceAsync( + se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest) { const Shape on_device_shape = HostShapeToDeviceShape(literal.shape()); TF_RET_CHECK(ShapeUtil::IsArray(on_device_shape)) @@ -51,28 +120,32 @@ Status TransferManager::TransferArrayToDevice( dest.size(), GetByteSizeRequirement(on_device_shape)); } ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape, - executor->platform(), executor->device_ordinal()); + stream->parent()->platform(), + stream->parent()->device_ordinal()); shaped_buffer.set_buffer(dest, /*index=*/{}); - return TransferLiteralToDevice(executor, literal, shaped_buffer); + return TransferLiteralToDevice(stream, literal, shaped_buffer); } -StatusOr> TransferManager::TransferArrayFromDevice( - se::StreamExecutor* executor, const Shape& shape, - const se::DeviceMemoryBase& source) { - TF_RET_CHECK(ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) - << "Shape " << ShapeUtil::HumanString(shape) - << " has a differently shaped representation on-device: " - << ShapeUtil::HumanString(HostShapeToDeviceShape(shape)); +void TransferManager::TransferArrayFromDevice( + se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source, + std::function>)> done) { + if (!ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) { + auto error = StrCat("Shape ", ShapeUtil::HumanString(shape), + " has a differently shaped representation on-device: ", + ShapeUtil::HumanString(HostShapeToDeviceShape(shape))); + return done(FailedPrecondition("%s", error.c_str())); + } if (source.size() < GetByteSizeRequirement(shape)) { - return FailedPrecondition( - "Allocation on device not large enough for array: " - "%lld < %lld", - source.size(), GetByteSizeRequirement(shape)); + return done( + FailedPrecondition("Allocation on device not large enough for array: " + "%lld < %lld", + source.size(), GetByteSizeRequirement(shape))); } ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape, - executor->platform(), executor->device_ordinal()); + stream->parent()->platform(), + stream->parent()->device_ordinal()); shaped_buffer.set_buffer(source, /*index=*/{}); - return TransferLiteralFromDevice(executor, shaped_buffer); + return TransferLiteralFromDevice(stream, shaped_buffer, std::move(done)); } /* static */ void TransferManager::RegisterTransferManager( @@ -108,10 +181,14 @@ StatusOr> TransferManager::TransferArrayFromDevice( } Status TransferManager::WriteTupleIndexTables( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer) { - VLOG(2) << "Writing tuple index tables for " << device_buffer; + se::Stream* stream, const ShapedBuffer& device_buffer) { + TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer)); + return stream->BlockHostUntilDone(); +} - TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); +Status TransferManager::WriteTupleIndexTablesAsync( + se::Stream* stream, const ShapedBuffer& device_buffer) { + VLOG(2) << "Writing tuple index tables for " << device_buffer; return ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_device_shape(), @@ -129,7 +206,7 @@ Status TransferManager::WriteTupleIndexTables( elements.push_back(device_buffer.buffer(element_index)); element_index.pop_back(); } - return WriteSingleTupleIndexTable(executor, elements, device_subshape, + return WriteSingleTupleIndexTable(stream, elements, device_subshape, &device_memory); } @@ -138,26 +215,20 @@ Status TransferManager::WriteTupleIndexTables( } Status TransferManager::TransferBufferFromDevice( - se::StreamExecutor* executor, const se::DeviceMemoryBase& source, - int64 size, void* destination) { + se::Stream* stream, const se::DeviceMemoryBase& source, int64 size, + void* destination) { if (source.size() < size) { return FailedPrecondition( "Source allocation on device not large enough for data tranfer: " "%lld < %lld", source.size(), size); } - auto copy_status = executor->SynchronousMemcpyD2H(source, size, destination); - if (!copy_status.ok()) { - return AddStatus( - Status(static_cast(copy_status.code()), - copy_status.error_message()), - "failed transfer from device to buffer"); - } + stream->ThenMemcpy(destination, source, size); return Status::OK(); } Status TransferManager::TransferBufferToDevice( - se::StreamExecutor* executor, int64 size, const void* source, + se::Stream* stream, int64 size, const void* source, se::DeviceMemoryBase* destination) { if (destination->size() < size) { return FailedPrecondition( @@ -165,13 +236,7 @@ Status TransferManager::TransferBufferToDevice( "%lld < %lld", destination->size(), size); } - auto copy_status = executor->SynchronousMemcpyH2D(source, size, destination); - if (!copy_status.ok()) { - return AddStatus( - Status(static_cast(copy_status.code()), - copy_status.error_message()), - "failed transfer of buffer to device"); - } + stream->ThenMemcpy(destination, source, size); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 43a8092b06fba0e2495bce0ee1a309c85a908273..e384359642a8fe09e0b8516e342a56259912922a 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -52,30 +52,65 @@ class TransferManager { return host_shape; } - // Returns a literal containing the data held in the given ShapedBuffer. - // using the provided executor. The optional literal_shape will be the shape - // for the literal. The shape of the ShapedBuffer and - // DeviceShape(literal_shape) must be compatible, but need not have the same - // layout. + // Returns a literal containing the data held in the given ShapedBuffer + // using the provided executor. This operation is performed synchronously + // without waiting for any other operation on a stream to complete. + // + // This function should be avoided in favor of the asynchronous version below. virtual StatusOr> TransferLiteralFromDevice( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer) = 0; + se::Stream* stream, const ShapedBuffer& device_buffer); + + // Begins transferring a literal containing the data held in the given + // ShapedBuffer using the provided executor. + // + // This operation is performed asynchronously on the given stream. It returns + // once the transfer is enqueued. 'done' is invoked with the result when + // complete. + // + // device_buffer is copied by reference and must live at least until done() is + // invoked. + virtual void TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + std::function>)> done) = 0; // Transfers the given literal into the previously allocated device memory // represented by the given ShapedBuffer using the given executor. The shape // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, - // but need not have the same layout - virtual Status TransferLiteralToDevice(se::StreamExecutor* executor, + // but need not have the same layout. + // + // This operation is performed synchronously without waiting for any other + // operation on a stream to complete. This function should be avoided in favor + // of the asynchronous version below. + virtual Status TransferLiteralToDevice(se::Stream* stream, const LiteralSlice& literal, - const ShapedBuffer& device_buffer) = 0; + const ShapedBuffer& device_buffer); + + // Transfers the given literal into the previously allocated device memory + // represented by the given ShapedBuffer using the given executor. The shape + // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, + // but need not have the same layout. + // + // This operation is performed asynchronously on the given stream. It returns + // once the transfer is enqueued. + virtual Status TransferLiteralToDeviceAsync( + se::Stream* stream, const LiteralSlice& literal, + const ShapedBuffer& device_buffer) = 0; // Convenience methods for transferring an array to or from the device at a // known address. This avoids having to construct a ShapedBuffer just to // transfer an array at a known address. - Status TransferArrayToDevice(se::StreamExecutor* executor, - const LiteralSlice& literal, + Status TransferArrayToDevice(se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest); + void TransferArrayFromDevice( + se::Stream* stream, const Shape& shape, + const se::DeviceMemoryBase& source, + std::function>)> done); + + Status TransferArrayToDeviceAsync(se::Stream* stream, + const LiteralSlice& literal, + const se::DeviceMemoryBase& dest); StatusOr> TransferArrayFromDevice( - se::StreamExecutor* executor, const Shape& shape, + se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source); // Transfers the given literal into the Infeed interface of the device, @@ -96,8 +131,10 @@ class TransferManager { // Given an allocated ShapedBuffer, constructs the tuple index table(s) in // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the // ShapedBuffer is array-shaped this method does nothing. - Status WriteTupleIndexTables(se::StreamExecutor* executor, + Status WriteTupleIndexTables(se::Stream* stream, const ShapedBuffer& device_buffer); + Status WriteTupleIndexTablesAsync(se::Stream* stream, + const ShapedBuffer& device_buffer); // Determines the byte size requirement for the given shape on the underlying // architecture. This will be used to allocate an appropriately sized memory @@ -144,7 +181,7 @@ class TransferManager { // 'destination' buffer. // // size is the size to transfer to destination in bytes. - virtual Status TransferBufferFromDevice(se::StreamExecutor* executor, + virtual Status TransferBufferFromDevice(se::Stream* stream, const se::DeviceMemoryBase& source, int64 size, void* destination); @@ -152,15 +189,15 @@ class TransferManager { // destination of the device. // // size is the size to transfer from source in bytes. - virtual Status TransferBufferToDevice(se::StreamExecutor* executor, - int64 size, const void* source, + virtual Status TransferBufferToDevice(se::Stream* stream, int64 size, + const void* source, se::DeviceMemoryBase* destination); // Writes the given device-memory pointers in 'elements' to the given region // to construct a tuple index table in the platform-specific tuple // representation. virtual Status WriteSingleTupleIndexTable( - se::StreamExecutor* executor, + se::Stream* stream, tensorflow::gtl::ArraySlice elements, const Shape& shape, se::DeviceMemoryBase* region) = 0; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index eb6d1ada6b553f998fe06917dfdf0b5092cd79cd..d1e174464759dbc2c0d84c4ddac27cb21635e131 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" @@ -121,7 +122,6 @@ void PointsToSet::add_tuple_source(const ShapeIndex& index, } namespace { - // Gather fusion instructions from 'instruction' into 'fusion_instructions'. void GatherFusionInstructions( HloInstruction* instruction, @@ -723,7 +723,8 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( return false; } if (user->opcode() == HloOpcode::kFusion) { - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop || + user->fusion_kind() == HloInstruction::FusionKind::kInput) { if (user->fused_expression_root()->opcode() == HloOpcode::kDynamicUpdateSlice) { // Loop fusion with kDynamicUpdateSlice fused root. @@ -732,6 +733,11 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( // 'operand_index', and this singleton use is the fused root at operand // index 0. return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0); + } else { + HloInstruction* fusion_param = + user->fused_parameter(user->operand_index(operand)); + return HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( + fusion_param); } } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && user->fused_expression_root()->opcode() == HloOpcode::kAdd) { diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 8831c513eee66e36163135b732f833d46cb7eb03..23519e445ea8a5f578a54708f38059feef3280c0 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -248,7 +248,9 @@ TEST_F(WhileLoopInvariantCodeMotionTest, TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + auto token_shape = ShapeUtil::MakeTokenShape(); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape}); HloComputation* while_body = [&]() { HloComputation::Builder builder(TestName() + ".while_body"); @@ -258,25 +260,32 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); HloInstruction* gte_1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* in_token = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(token_shape, param, 2)); + HloInstruction* out_token = builder.AddInstruction( + HloInstruction::CreateOutfeed(scalar_s32, gte_0, in_token, "")); builder.AddInstruction( - HloInstruction::CreateOutfeed(scalar_s32, gte_0, "")); - builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1})); + HloInstruction::CreateTuple({gte_0, gte_1, out_token})); return module().AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); + auto* scalar_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_s32, "param")); + auto* token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto* init_value = builder.AddInstruction( - HloInstruction::CreateParameter(0, while_shape, "init_value")); + HloInstruction::CreateTuple({scalar_param, scalar_param, token})); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( while_shape, MakeAlwaysTrueComputation(while_shape, &module()), while_body, init_value)); - + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0)); module().AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, WhileLoopInvariantCodeMotion{}.Run(&module())); - EXPECT_FALSE(simplified_loop); + ASSERT_FALSE(simplified_loop); EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Outfeed())); @@ -287,7 +296,9 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { // bitcast either. auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); - Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + auto token_shape = ShapeUtil::MakeTokenShape(); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape}); HloComputation* while_body = [&]() { HloComputation::Builder builder(TestName() + ".while_body"); @@ -297,21 +308,29 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); HloInstruction* gte_1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* in_token = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(token_shape, param, 2)); HloInstruction* bitcast_inst = builder.AddInstruction( HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0)); + HloInstruction* out_token = builder.AddInstruction( + HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, in_token, "")); builder.AddInstruction( - HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, "")); - builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1})); + HloInstruction::CreateTuple({gte_0, gte_1, out_token})); return module().AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); + auto* scalar_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_s32, "param")); + auto* token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); auto* init_value = builder.AddInstruction( - HloInstruction::CreateParameter(0, while_shape, "init_value")); + HloInstruction::CreateTuple({scalar_param, scalar_param, token})); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( while_shape, MakeAlwaysTrueComputation(while_shape, &module()), while_body, init_value)); + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0)); module().AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 619e87caa5b6d0f6ec3c3b1489b0d4f50ef29963..0536c99b671ff37d67bb0fc7f9ab0b806d15f016 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -208,8 +208,9 @@ TEST_F(WhileLoopSimplifierTest, LoopWithInfeedNotSimplified) { auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); - while_body->AddInstruction( - HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config")); + auto token = while_body->AddInstruction(HloInstruction::CreateAfterAll({})); + while_body->AddInstruction(HloInstruction::CreateInfeed( + ShapeUtil::MakeShape(F32, {1}), token, "config")); EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); } diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index d79d3297213e832306ea4726483b0f215df0f5d3..2ccb919acf9c4e7c59a1ebaf36f42a6781068b5e 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -179,7 +179,9 @@ body { cond { param.c = (s32[], s32[]) parameter(0) - ROOT condition = pred[] infeed() + token = token[] after-all() + infeed = (pred[], token[]) infeed(token) + ROOT condition = pred[] get-tuple-element(infeed), index=0 } ENTRY main { diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 18e54d23c241ae0d4c61d8be79ff021dfb02a3e6..4aacc87b78e2c271829cdf397cd69bfb490125b8 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -105,8 +105,8 @@ class ShapeTree { // Returns the data element associated with the array in the shape at the // given index (see ShapeUtil::GetSubshape for how indexes are defined). - const T& element(const ShapeIndex& index) const; - T* mutable_element(const ShapeIndex& index); + const T& element(ShapeIndexView index) const; + T* mutable_element(ShapeIndexView index); // Return the shape represented with this ShapeTree. const Shape& shape() const { return *shape_; } @@ -125,7 +125,7 @@ class ShapeTree { // Returns true if the node at the given index is a leaf node (an array // shape). - bool IsLeaf(const ShapeIndex& index) const { return Lookup(index)->is_leaf; } + bool IsLeaf(ShapeIndexView index) const { return Lookup(index)->is_leaf; } ShapeTree(const ShapeTree&) = default; ShapeTree& operator=(const ShapeTree&) = default; @@ -211,12 +211,12 @@ class ShapeTree { // Returns an iterator pointing to the given ShapeIndex. // REQUIRES: index must exist in the ShapeTree. - iterator find(const ShapeIndex& index) { + iterator find(ShapeIndexView index) { Node* element = Lookup(index); return iterator(&nodes_, typename std::vector::iterator(element), /*iterate_leaves_only=*/false); } - const_iterator find(const ShapeIndex& index) const { + const_iterator find(ShapeIndexView index) const { Node* element = Lookup(index); return iterator(&nodes_, typename std::vector::const_iterator(element), @@ -285,8 +285,8 @@ class ShapeTree { static Status ForEachMutableHelper(const Fn& func, std::vector* nodes); // Return the tree node at the given index. - Node* Lookup(const ShapeIndex& index); - const Node* Lookup(const ShapeIndex& index) const; + Node* Lookup(ShapeIndexView index); + const Node* Lookup(ShapeIndexView index) const; // The nodes in this shape tree. std::vector nodes_; @@ -463,17 +463,17 @@ ShapeTree::ShapeTree(const std::shared_ptr& shape, } template -const T& ShapeTree::element(const ShapeIndex& index) const { +const T& ShapeTree::element(ShapeIndexView index) const { return Lookup(index)->data.second; } template -T* ShapeTree::mutable_element(const ShapeIndex& index) { +T* ShapeTree::mutable_element(ShapeIndexView index) { return &Lookup(index)->data.second; } template -internal::ShapeTreeNode* ShapeTree::Lookup(const ShapeIndex& index) { +internal::ShapeTreeNode* ShapeTree::Lookup(ShapeIndexView index) { Node* node = &nodes_[0]; for (const int64 i : index) { CHECK_GE(i, 0); @@ -485,7 +485,7 @@ internal::ShapeTreeNode* ShapeTree::Lookup(const ShapeIndex& index) { template const internal::ShapeTreeNode* ShapeTree::Lookup( - const ShapeIndex& index) const { + ShapeIndexView index) const { return const_cast(this)->Lookup(index); } diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index c85fb20e01c1c8b7a8fc0d2b10881e5f9feed977..2166c34358fa62815c3fb32f28392f9036e25158 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -24,9 +24,11 @@ limitations under the License. #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/overflow_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/iterator_range.h" @@ -93,8 +95,11 @@ bool IsArrayPrimitiveType(PrimitiveType primitive_type) { // Recursive helper for comparing the equality of two shapes. Returns true if // the shapes are the same. If compare_layouts is true, then layouts must also // match. -bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { - if (!ShapeUtil::SameElementType(lhs, rhs)) { +bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, + bool ignore_fp_precision) { + if ((ignore_fp_precision && + !ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) || + (!ignore_fp_precision && !ShapeUtil::SameElementType(lhs, rhs))) { VLOG(3) << "CompareShapes: lhs element type != rhs element type"; return false; } @@ -102,7 +107,8 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { if (ShapeUtil::IsTuple(lhs)) { return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), [=](const Shape& l, const Shape& r) { - return CompareShapes(l, r, compare_layouts); + return CompareShapes(l, r, compare_layouts, + ignore_fp_precision); }); } else if (!ShapeUtil::IsArray(lhs)) { // Non-tuple, non-array tupes such as opaque and token types are trivially @@ -169,7 +175,8 @@ StatusOr MakeShapeWithLayoutInternal( } // namespace /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) { - bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true); + bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true, + /*ignore_fp_precision=*/false); if (!equal && VLOG_IS_ON(3)) { VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString(); @@ -178,6 +185,18 @@ StatusOr MakeShapeWithLayoutInternal( return equal; } +/* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs, + const Shape& rhs) { + bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true, + /*ignore_fp_precision=*/true); + if (!equal && VLOG_IS_ON(3)) { + VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = " + << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString(); + } + + return equal; +} + /* static */ int64 ShapeUtil::Rank(const Shape& shape) { CHECK(ShapeUtil::IsArray(shape)) << "Non-arrays do not have a rank, shape: " << shape; @@ -263,6 +282,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( tensorflow::gtl::ArraySlice shapes) { Shape result; result.set_element_type(TUPLE); + result.mutable_tuple_shapes()->Reserve(shapes.size()); for (const auto& shape : shapes) { AppendShapeToTuple(shape, &result); } @@ -379,6 +399,13 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return shape.tuple_shapes(index); } +/* static */ int64 ShapeUtil::SubshapeCount(const Shape& shape) { + int64 n = 0; + ForEachSubshape(shape, [&](const Shape& literal_subshape, + const ShapeIndex& index) { ++n; }); + return n; +} + /* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64 start, int64 limit) { TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple)); @@ -413,6 +440,18 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( std::multiplies()); } +/* static */ int64 ShapeUtil::ElementsInRecursive(const Shape& shape) { + CHECK(IsArray(shape) || IsTuple(shape)); + if (IsArray(shape)) { + return ElementsIn(shape); + } + int64 count = 0; + for (const Shape& element_shape : shape.tuple_shapes()) { + count += ElementsInRecursive(element_shape); + } + return count; +} + /* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) { return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0; } @@ -421,7 +460,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return shape.element_type() == F32 && Rank(shape) == 0; } - namespace { // Class to memoize the computation of @@ -554,12 +592,11 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { // tensorflow::StringPiece is not compatible with internal RE2 StringPiece, so // we convert in to the RE2-consumable type and then consume the corresponding // amount from our StringPiece type. + static LazyRE2 shape_pattern = { + "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?"}; tensorflow::RegexpStringPiece s_consumable(s->data(), s->size()); - if (RE2::Consume( - &s_consumable, - "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?", - &element_type_string, &dimensions_string, &format_string, - &layout_string)) { + if (RE2::Consume(&s_consumable, *shape_pattern, &element_type_string, + &dimensions_string, &format_string, &layout_string)) { size_t consumed = s->size() - s_consumable.size(); s->remove_prefix(consumed); auto string_to_int64 = [&s](const string& input) -> StatusOr { @@ -645,7 +682,8 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { } /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { - return CompareShapes(lhs, rhs, /*compare_layouts=*/false); + return CompareShapes(lhs, rhs, /*compare_layouts=*/false, + /*ignore_fp_precision=*/false); } /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, @@ -847,6 +885,53 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { } } + TF_RETURN_IF_ERROR(ValidateShapeSize(shape)); + return Status::OK(); +} + +/* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) { + VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape); + auto invalid_argument = + InvalidArgument("Shape %s size may overflow int64.", + ShapeUtil::HumanString(shape).c_str()); + if (!IsArray(shape)) { + return Status::OK(); + } + int64 shape_size; + if (LayoutUtil::IsSparseArray(shape)) { + shape_size = LayoutUtil::MaxSparseElements(shape.layout()); + if (shape_size < 0) { + return invalid_argument; + } + shape_size = MultiplyWithoutOverflow(shape_size, ShapeUtil::Rank(shape)); + if (shape_size < 0) { + return invalid_argument; + } + shape_size = MultiplyWithoutOverflow(shape_size, sizeof(int64)); + if (shape_size < 0) { + return invalid_argument; + } + } + + // This is intentionally unconditional: even if the shape is sparse, we want + // to verify the densified version has a reasonable size. + if (shape.dimensions().empty()) { + return Status::OK(); + } + shape_size = 1; + for (int64 dim : shape.dimensions()) { + shape_size = MultiplyWithoutOverflow(shape_size, dim); + if (shape_size < 0) { + return invalid_argument; + } + } + shape_size = MultiplyWithoutOverflow( + shape_size, ByteSizeOfPrimitiveType(shape.element_type())); + if (shape_size < 0) { + return invalid_argument; + } + + VLOG(3) << "Shape size is valid: " << shape_size; return Status::OK(); } @@ -946,6 +1031,11 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { return leaves; } +/* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) { + CHECK(ShapeUtil::IsArray(shape)); + return ArrayContains(AsInt64Slice(shape.dimensions()), 1); +} + namespace { // Helper for ForEachSubshape which visits the subshapes of the given shape in diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 8ee3f490a0837ec363758f6c633d73aa57687db4..5ae04451d32bd733dce55c4a56f5ebc1882d9fbd 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -175,6 +175,9 @@ class ShapeUtil { // Precondition: IsArray(shape) static int64 ElementsIn(const Shape& shape); + // As ElementsIn(), but recurses through tuples. + static int64 ElementsInRecursive(const Shape& shape); + // Returns true if 'shape' is an array with zero elements. static bool IsZeroElementArray(const Shape& shape); @@ -277,6 +280,9 @@ class ShapeUtil { // Returns whether the lhs and rhs shapes are identical protobufs. static bool Equal(const Shape& lhs, const Shape& rhs); + // As Equal, but allow one of lhs and rhs to be F16 while the other is F32. + static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); + // Returns the rank (number of dimensions) of the given shape. // Precondition: !IsTuple(shape) static int64 Rank(const Shape& shape); @@ -457,6 +463,9 @@ class ShapeUtil { // Precondition: IsTuple(shape) && TupleElementCount(shape) > index static const Shape& GetTupleElementShape(const Shape& shape, int64 index); + // Returns the number of elements, recursively, in the given shape. + static int64 SubshapeCount(const Shape& shape); + // Slices tuple elements in the range [start, limit) and returns a new tuple // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32). static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit); @@ -516,6 +525,10 @@ class ShapeUtil { static Status ForEachMutableSubshapeWithStatus( Shape* shape, const MutatingStatusVisitorFunction& func); + // Returns true if `shape` (which must be an array) with degenerate dimensions + // (dimensions with bound 1). + static bool HasDegenerateDimensions(const Shape& shape); + // Permutes the dimensions by the given permutation, so // return_value.dimensions[permutation[i]] = argument.dimensions[i] static Shape PermuteDimensions(tensorflow::gtl::ArraySlice permutation, @@ -689,6 +702,10 @@ class ShapeUtil { static size_t Hash(const Shape& shape); private: + // Validates the shape size is sane. This makes sure it's safe to do + // calculations in int64 without overflowing. + static Status ValidateShapeSize(const Shape& shape); + // Validates all of the non-layout properties of the shape -- this is a helper // used by both the layout-optional and layout-required public method. static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 61aa198e524373f84b7e950d5835dd2457c88a62..b6f30af381dd8d24ff28fdf7f729d6cb3df46ec9 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -242,6 +242,24 @@ TEST(ShapeUtilTest, IncompatibleDifferentElementShapes) { EXPECT_FALSE(ShapeUtil::Compatible(shape_1, shape_2)); } +TEST(ShapeUtilTest, EqualIgnoringFpPrecision) { + EXPECT_TRUE(ShapeUtil::EqualIgnoringFpPrecision( + ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(F16, {4, 3}, {0, 1}))); +} + +TEST(ShapeUtilTest, UnequalIgnoringFpPrecision) { + EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision( + ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {0, 1}))); + EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision( + ShapeUtil::MakeShapeWithLayout(F32, {3, 4}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {1, 0}))); + EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision( + ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(PRED, {4, 3}, {0, 1}))); +} + TEST(ShapeUtilTest, CompatibleTuples) { Shape tuple1 = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})}); @@ -792,6 +810,17 @@ TEST(ShapeUtilTest, ReshapeIsBitcast_3x2x2_6x2_Dim1IsMostMinor) { ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1}))); } +TEST(ShapeUtilTest, HasDegenerateDimensions) { + EXPECT_TRUE( + ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 1, 2}))); + EXPECT_TRUE( + ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 1, 1}))); + EXPECT_FALSE( + ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 3, 5}))); + EXPECT_FALSE( + ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 0, 5}))); +} + TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast( ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}), diff --git a/tensorflow/compiler/xla/statusor.h b/tensorflow/compiler/xla/statusor.h index 0e1387c93938fa520562fcd63ac107a82b089a51..a32e2ad9851b0b5644f7e6f0f9ead6c438934c07 100644 --- a/tensorflow/compiler/xla/statusor.h +++ b/tensorflow/compiler/xla/statusor.h @@ -12,297 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -// StatusOr is the union of a Status object and a T object. StatusOr models -// the concept of an object that is either a value, or an error Status -// explaining why such a value is not present. To this end, StatusOr does not -// allow its Status value to be Status::OK. -// -// The primary use-case for StatusOr is as the return value of a -// function which may fail. -// -// Example client usage for a StatusOr, where T is not a pointer: -// -// StatusOr result = DoBigCalculationThatCouldFail(); -// if (result.ok()) { -// float answer = result.ValueOrDie(); -// printf("Big calculation yielded: %f", answer); -// } else { -// LOG(ERROR) << result.status(); -// } -// -// Example client usage for a StatusOr: -// -// StatusOr result = FooFactory::MakeNewFoo(arg); -// if (result.ok()) { -// std::unique_ptr foo(result.ValueOrDie()); -// foo->DoSomethingCool(); -// } else { -// LOG(ERROR) << result.status(); -// } -// -// Example client usage for a StatusOr>: -// -// StatusOr> result = FooFactory::MakeNewFoo(arg); -// if (result.ok()) { -// std::unique_ptr foo = std::move(result.ValueOrDie()); -// foo->DoSomethingCool(); -// } else { -// LOG(ERROR) << result.status(); -// } -// -// Example factory implementation returning StatusOr: -// -// StatusOr FooFactory::MakeNewFoo(int arg) { -// if (arg <= 0) { -// return tensorflow::InvalidArgument("Arg must be positive"); -// } else { -// return new Foo(arg); -// } -// } -// -// Note that the assignment operators require that destroying the currently -// stored value cannot invalidate the argument; in other words, the argument -// cannot be an alias for the current value, or anything owned by the current -// value. #ifndef TENSORFLOW_COMPILER_XLA_STATUSOR_H_ #define TENSORFLOW_COMPILER_XLA_STATUSOR_H_ #include "tensorflow/compiler/xla/status.h" -#include "tensorflow/compiler/xla/statusor_internals.h" -#include "tensorflow/core/platform/macros.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace xla { -#if defined(__clang__) -// Only clang supports warn_unused_result as a type annotation. -template -class TF_MUST_USE_RESULT StatusOr; -#endif - -template -class StatusOr : private internal_statusor::StatusOrData, - private internal_statusor::TraitsBase< - std::is_copy_constructible::value, - std::is_move_constructible::value> { - template - friend class StatusOr; - - typedef internal_statusor::StatusOrData Base; - - public: - typedef T element_type; - - // Constructs a new StatusOr with Status::UNKNOWN status. This is marked - // 'explicit' to try to catch cases like 'return {};', where people think - // StatusOr> will be initialized with an empty vector, - // instead of a Status::UNKNOWN status. - explicit StatusOr(); - - // StatusOr will be copy constructible/assignable if T is copy - // constructible. - StatusOr(const StatusOr&) = default; - StatusOr& operator=(const StatusOr&) = default; - - // StatusOr will be move constructible/assignable if T is move - // constructible. - StatusOr(StatusOr&&) = default; - StatusOr& operator=(StatusOr&&) = default; - - // Conversion copy/move constructor, T must be convertible from U. - template ::value>::type* = nullptr> - StatusOr(const StatusOr& other); - template ::value>::type* = nullptr> - StatusOr(StatusOr&& other); - - // Conversion copy/move assignment operator, T must be convertible from U. - template ::value>::type* = nullptr> - StatusOr& operator=(const StatusOr& other); - template ::value>::type* = nullptr> - StatusOr& operator=(StatusOr&& other); - - // Constructs a new StatusOr with the given value. After calling this - // constructor, calls to ValueOrDie() will succeed, and calls to status() will - // return OK. - // - // NOTE: Not explicit - we want to use StatusOr as a return type - // so it is convenient and sensible to be able to do 'return T()' - // when the return type is StatusOr. - // - // REQUIRES: T is copy constructible. - StatusOr(const T& value); - - // Constructs a new StatusOr with the given non-ok status. After calling - // this constructor, calls to ValueOrDie() will CHECK-fail. - // - // NOTE: Not explicit - we want to use StatusOr as a return - // value, so it is convenient and sensible to be able to do 'return - // Status()' when the return type is StatusOr. - // - // REQUIRES: !status.ok(). This requirement is DCHECKed. - // In optimized builds, passing Status::OK() here will have the effect - // of passing tensorflow::error::INTERNAL as a fallback. - StatusOr(const Status& status); - StatusOr& operator=(const Status& status); - - // TODO(b/62186997): Add operator=(T) overloads. - - // Similar to the `const T&` overload. - // - // REQUIRES: T is move constructible. - StatusOr(T&& value); - - // RValue versions of the operations declared above. - StatusOr(Status&& status); - StatusOr& operator=(Status&& status); - - // Returns this->status().ok() - bool ok() const { return this->status_.ok(); } - - // Returns a reference to our status. If this contains a T, then - // returns Status::OK(). - const Status& status() const &; - Status status() &&; - - // Returns a reference to our current value, or CHECK-fails if !this->ok(). - // - // Note: for value types that are cheap to copy, prefer simple code: - // - // T value = statusor.ValueOrDie(); - // - // Otherwise, if the value type is expensive to copy, but can be left - // in the StatusOr, simply assign to a reference: - // - // T& value = statusor.ValueOrDie(); // or `const T&` - // - // Otherwise, if the value type supports an efficient move, it can be - // used as follows: - // - // T value = std::move(statusor).ValueOrDie(); - // - // The std::move on statusor instead of on the whole expression enables - // warnings about possible uses of the statusor object after the move. - // C++ style guide waiver for ref-qualified overloads granted in cl/143176389 - // See go/ref-qualifiers for more details on such overloads. - const T& ValueOrDie() const &; - T& ValueOrDie() &; - const T&& ValueOrDie() const &&; - T&& ValueOrDie() &&; - - T ConsumeValueOrDie() { return std::move(ValueOrDie()); } - - // Ignores any errors. This method does nothing except potentially suppress - // complaints from any tools that are checking that errors are not dropped on - // the floor. - void IgnoreError() const; -}; - -//////////////////////////////////////////////////////////////////////////////// -// Implementation details for StatusOr - -template -StatusOr::StatusOr() : Base(Status(tensorflow::error::UNKNOWN, "")) {} - -template -StatusOr::StatusOr(const T& value) : Base(value) {} - -template -StatusOr::StatusOr(const Status& status) : Base(status) {} - -template -StatusOr& StatusOr::operator=(const Status& status) { - this->Assign(status); - return *this; -} - -template -StatusOr::StatusOr(T&& value) : Base(std::move(value)) {} - -template -StatusOr::StatusOr(Status&& status) : Base(std::move(status)) {} - -template -StatusOr& StatusOr::operator=(Status&& status) { - this->Assign(std::move(status)); - return *this; -} - -template -template ::value>::type*> -inline StatusOr::StatusOr(const StatusOr& other) - : Base(static_cast::Base&>(other)) {} - -template -template ::value>::type*> -inline StatusOr& StatusOr::operator=(const StatusOr& other) { - if (other.ok()) - this->Assign(other.ValueOrDie()); - else - this->Assign(other.status()); - return *this; -} - -template -template ::value>::type*> -inline StatusOr::StatusOr(StatusOr&& other) - : Base(static_cast::Base&&>(other)) {} - -template -template ::value>::type*> -inline StatusOr& StatusOr::operator=(StatusOr&& other) { - if (other.ok()) { - this->Assign(std::move(other).ValueOrDie()); - } else { - this->Assign(std::move(other).status()); - } - return *this; -} - -template -const Status& StatusOr::status() const & { - return this->status_; -} -template -Status StatusOr::status() && { - return ok() ? Status::OK() : std::move(this->status_); -} - -template -const T& StatusOr::ValueOrDie() const & { - this->EnsureOk(); - return this->data_; -} - -template -T& StatusOr::ValueOrDie() & { - this->EnsureOk(); - return this->data_; -} - -template -const T&& StatusOr::ValueOrDie() const && { - this->EnsureOk(); - return std::move(this->data_); -} - -template -T&& StatusOr::ValueOrDie() && { - this->EnsureOk(); - return std::move(this->data_); -} - +// Use steam_executor's StatusOr so we don't duplicate code. template -void StatusOr::IgnoreError() const { - // no-op -} +using StatusOr = ::stream_executor::port::StatusOr; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index e7e0a19db0516e4210f6bb78d6b5e6968bf78b2a..20b2885e90d8ae087bb1b49cfdfd757f51ddae73 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -697,6 +697,7 @@ xla_test( "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -1248,6 +1249,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/compiler/xla/tests:client_library_test_base", @@ -1986,6 +1988,7 @@ xla_test( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:test", ], ) @@ -2037,6 +2040,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index c3a289ee09cc1ee7b9d705a38c26a3ac7a8a6aa2..3bdf98544affca11fd825e28d20f4903188fe920 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -51,16 +51,16 @@ class ArrayElementwiseOpTestParamCount XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.Neg(a); + auto a = ConstantR1(&builder, {}); + Neg(a); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); - builder.Neg(a); + auto a = ConstantR1(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); + Neg(a); ComputeAndCompareR1(&builder, {2.5f, -3.14f, -2.25f, 10.0f, -6.0f}, {}, error_spec_); @@ -68,10 +68,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) { XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-1, 0, 1, 324, - std::numeric_limits::min(), - std::numeric_limits::max()}); - builder.Neg(a); + auto a = ConstantR1(&builder, + {-1, 0, 1, 324, std::numeric_limits::min(), + std::numeric_limits::max()}); + Neg(a); // -min == min for int32 due to an overflow. In C++ it is undefined behavior // to do this calculation. For XLA we have not specified that, so it @@ -84,17 +84,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) { XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.Neg(a); + auto a = ConstantR1(&builder, {}); + Neg(a); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}}); - builder.Neg(a); + auto a = ConstantR1( + &builder, {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}}); + Neg(a); ComputeAndCompareR1( &builder, {{2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}}, @@ -103,16 +103,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) { XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({ - -1, - 1, - 0, - 0x12345678, - static_cast(0xffffffff12345678l), - static_cast(0x8000000000000000LL), - static_cast(0x8000000000000001LL), - }); - builder.Neg(a); + auto a = + ConstantR1(&builder, { + -1, + 1, + 0, + 0x12345678, + static_cast(0xffffffff12345678l), + static_cast(0x8000000000000000LL), + static_cast(0x8000000000000001LL), + }); + Neg(a); LOG(INFO) << -static_cast(0x7FFFFFFFFFFFFFFFLL); ComputeAndCompareR1(&builder, @@ -130,8 +131,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) { XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.IsFinite(a); + auto a = ConstantR1(&builder, {}); + IsFinite(a); ComputeAndCompareR1(&builder, {}, {}); } @@ -141,21 +142,21 @@ static const float kNonCanonicalNaN = tensorflow::bit_cast(0x7FD01234); XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) { XlaBuilder builder(TestName()); - builder.IsFinite(builder.ConstantR0(NAN)); + IsFinite(ConstantR0(&builder, NAN)); ComputeAndCompareR0(&builder, false, {}); EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); - builder.IsFinite(builder.ConstantR0(kNonCanonicalNaN)); + IsFinite(ConstantR0(&builder, kNonCanonicalNaN)); ComputeAndCompareR0(&builder, false, {}); const float inf = std::numeric_limits::infinity(); - builder.IsFinite(builder.ConstantR0(inf)); + IsFinite(ConstantR0(&builder, inf)); ComputeAndCompareR0(&builder, false, {}); - builder.IsFinite(builder.ConstantR0(-inf)); + IsFinite(ConstantR0(&builder, -inf)); ComputeAndCompareR0(&builder, false, {}); - builder.IsFinite(builder.ConstantR0(0.0f)); + IsFinite(ConstantR0(&builder, 0.0f)); ComputeAndCompareR0(&builder, true, {}); } @@ -163,9 +164,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) { XlaBuilder builder(TestName()); const float inf = std::numeric_limits::infinity(); EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); - auto a = builder.ConstantR1( - {{NAN, 7.0f, kNonCanonicalNaN, -1.0f, inf, -inf}}); - builder.IsFinite(a); + auto a = ConstantR1(&builder, + {{NAN, 7.0f, kNonCanonicalNaN, -1.0f, inf, -inf}}); + IsFinite(a); ComputeAndCompareR1(&builder, {false, true, false, true, false, false}, {}); @@ -173,9 +174,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) { XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); - auto b = builder.ConstantR1({100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); - builder.Add(a, b); + auto a = ConstantR1(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); + auto b = ConstantR1(&builder, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); + Add(a, b); ComputeAndCompareR1(&builder, {97.5f, 6.27f, 5.0f, 0.5f, -993.0f}, {}, error_spec_); @@ -183,20 +184,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) { XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Add(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Add(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}}); - auto b = builder.ConstantR1( - {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}}); - builder.Add(a, b); + auto a = ConstantR1( + &builder, {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}}); + auto b = ConstantR1( + &builder, {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}}); + Add(a, b); ComputeAndCompareR1( &builder, {97.5f, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}}, {}, @@ -205,9 +206,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) { XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Add(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Add(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -225,7 +226,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { 0x8000000000000000LL, 1}; std::unique_ptr lhs_literal = Literal::CreateR1({lhs}); - auto lhs_param = b.Parameter(0, lhs_literal->shape(), "lhs_param"); + auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); std::unique_ptr lhs_data = client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); @@ -239,11 +240,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { 1, 0x8000000000000000LL}; std::unique_ptr rhs_literal = Literal::CreateR1({rhs}); - auto rhs_param = b.Parameter(1, rhs_literal->shape(), "rhs_param"); + auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); std::unique_ptr rhs_data = client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); - b.Add(lhs_param, rhs_param); + Add(lhs_param, rhs_param); std::vector expected(lhs.size()); for (int64 i = 0; i < lhs.size(); ++i) { @@ -265,7 +266,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { 0, -1}; std::unique_ptr lhs_literal = Literal::CreateR1({lhs}); - auto lhs_param = b.Parameter(0, lhs_literal->shape(), "lhs_param"); + auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); std::unique_ptr lhs_data = client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); @@ -278,11 +279,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { 0x7FFFFFFFFFFFFFFFLL, 0x7FFFFFFFFFFFFFFFLL}; std::unique_ptr rhs_literal = Literal::CreateR1({rhs}); - auto rhs_param = b.Parameter(1, rhs_literal->shape(), "rhs_param"); + auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); std::unique_ptr rhs_data = client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); - auto sub = b.Sub(lhs_param, rhs_param); + Sub(lhs_param, rhs_param); std::vector expected(lhs.size()); for (int64 i = 0; i < lhs.size(); ++i) { @@ -305,23 +306,23 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { std::unique_ptr a_literal = Literal::CreateR1({a_values}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a_constant = builder.ConstantR1(a_values); - auto a_param = builder.Parameter(0, a_literal->shape(), "a_param"); + auto a_constant = ConstantR1(&builder, a_values); + auto a_param = Parameter(&builder, 0, a_literal->shape(), "a_param"); std::unique_ptr b_literal = Literal::CreateR1({b_values}); std::unique_ptr b_data = client_->TransferToServer(*b_literal).ConsumeValueOrDie(); - auto b_constant = builder.Parameter(1, a_literal->shape(), "b_param"); - auto b_param = builder.ConstantR1(b_values); + auto b_constant = Parameter(&builder, 1, a_literal->shape(), "b_param"); + auto b_param = ConstantR1(&builder, b_values); - auto sum1 = builder.Add(a_constant, b_constant); - auto sum2 = builder.Add(a_constant, b_param); - auto sum3 = builder.Add(a_param, b_constant); - auto sum4 = builder.Add(a_param, b_param); + auto sum1 = Add(a_constant, b_constant); + auto sum2 = Add(a_constant, b_param); + auto sum3 = Add(a_param, b_constant); + auto sum4 = Add(a_param, b_param); - auto sum = builder.Add(sum1, sum2); - sum = builder.Add(sum, sum3); - sum = builder.Add(sum, sum4); + auto sum = Add(sum1, sum2); + sum = Add(sum, sum3); + sum = Add(sum, sum4); std::vector expected; for (int64 i = 0; i < count; ++i) { @@ -334,9 +335,9 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); - auto b = builder.ConstantR1({100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); - builder.Sub(a, b); + auto a = ConstantR1(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); + auto b = ConstantR1(&builder, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); + Sub(a, b); ComputeAndCompareR1(&builder, {-102.5f, 0.01f, -0.5f, -20.5f, 1005.0f}, {}, error_spec_); @@ -344,38 +345,38 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) { XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Sub(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Sub(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-1, 0, 2, 1000000000}); - auto b = builder.ConstantR1({-1, 2, 1, -1}); - builder.Sub(a, b); + auto a = ConstantR1(&builder, {-1, 0, 2, 1000000000}); + auto b = ConstantR1(&builder, {-1, 2, 1, -1}); + Sub(a, b); ComputeAndCompareR1(&builder, {0, -2, 1, 1000000001}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Sub(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Sub(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}}); - auto b = builder.ConstantR1( - {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}}); - builder.Sub(a, b); + auto a = ConstantR1(&builder, + {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}}); + auto b = ConstantR1( + &builder, {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}}); + Sub(a, b); ComputeAndCompareR1( &builder, {{-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}}, {}, @@ -384,18 +385,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) { XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Sub(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Sub(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); - auto b = builder.ConstantR1({10.0f, 5.1f, 1.0f, 10.0f, -6.0f}); - builder.Div(a, b); + auto a = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); + auto b = ConstantR1(&builder, {10.0f, 5.1f, 1.0f, 10.0f, -6.0f}); + Div(a, b); ComputeAndCompareR1(&builder, {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f}, {}, error_spec_); @@ -403,9 +404,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) { XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Div(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Div(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -442,7 +443,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); - builder.Div(dividend, divisor); + Div(dividend, divisor); ComputeAndCompareR1(&builder, quotients, {dividend_data.get(), divisor_data.get()}); @@ -454,7 +455,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); - builder.Div(dividend, builder.ConstantR1(divisors)); + Div(dividend, ConstantR1(&builder, divisors)); ComputeAndCompareR1(&builder, quotients, {dividend_data.get()}); } @@ -467,7 +468,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); - builder.Rem(dividend, divisor); + Rem(dividend, divisor); ComputeAndCompareR1(&builder, remainders, {dividend_data.get(), divisor_data.get()}); @@ -479,7 +480,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); - builder.Rem(dividend, builder.ConstantR1(divisors)); + Rem(dividend, ConstantR1(&builder, divisors)); ComputeAndCompareR1(&builder, remainders, {dividend_data.get()}); } @@ -513,7 +514,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { &builder, ÷nd); auto divisor_data = CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); - builder.Div(dividend, divisor); + Div(dividend, divisor); ComputeAndCompareR1(&builder, quotients, {dividend_data.get(), divisor_data.get()}); @@ -524,7 +525,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); - builder.Div(dividend, builder.ConstantR1(divisors)); + Div(dividend, ConstantR1(&builder, divisors)); ComputeAndCompareR1(&builder, quotients, {dividend_data.get()}); } @@ -537,7 +538,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { &builder, ÷nd); auto divisor_data = CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); - builder.Rem(dividend, divisor); + Rem(dividend, divisor); ComputeAndCompareR1(&builder, remainders, {dividend_data.get(), divisor_data.get()}); @@ -548,7 +549,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); - builder.Rem(dividend, builder.ConstantR1(divisors)); + Rem(dividend, ConstantR1(&builder, divisors)); ComputeAndCompareR1(&builder, remainders, {dividend_data.get()}); } @@ -556,11 +557,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}}); - auto b = builder.ConstantR1( - {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}}); - builder.Div(a, b); + auto a = ConstantR1( + &builder, {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}}); + auto b = ConstantR1(&builder, + {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}}); + Div(a, b); ComputeAndCompareR1( &builder, {{-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}}, {}, error_spec_); @@ -568,20 +569,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) { XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Div(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Div(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f}); - auto b = builder.ConstantR1( - {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f}); - builder.Rem(a, b); + auto a = ConstantR1( + &builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f}); + auto b = ConstantR1( + &builder, {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f}); + Rem(a, b); ComputeAndCompareR1( &builder, {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f, 1.0f, 1.0f, -1.0f, -0.0f}, {}, @@ -590,20 +591,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Rem(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Rem(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0}); - auto b = builder.ConstantR1( - {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0}); - builder.Rem(a, b); + auto a = ConstantR1( + &builder, {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0}); + auto b = ConstantR1( + &builder, {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0}); + Rem(a, b); ComputeAndCompareR1( &builder, {-2.5, 0.0, 0.25, 0.0, -0.0, 1.0, 1.0, -1.0, -0.0}, {}, @@ -612,9 +613,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) { XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); - auto b = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); - builder.Mul(a, b); + auto a = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); + auto b = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); + Mul(a, b); ComputeAndCompareR1(&builder, {-25.0f, 127.5f, 2.25f, -100.0f, -36.0f}, {}, error_spec_); @@ -622,9 +623,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) { XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Mul(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Mul(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -648,18 +649,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) { } XlaBuilder builder(TestName()); - auto a = builder.ConstantR1(a_data); - auto b = builder.ConstantR1(b_data); - builder.Mul(a, b); + auto a = ConstantR1(&builder, a_data); + auto b = ConstantR1(&builder, b_data); + Mul(a, b); ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Mul(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Mul(a, b); ComputeAndCompareR1(&builder, {}, {}); } @@ -679,20 +680,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) { } XlaBuilder builder(TestName()); - auto a = builder.ConstantR1(a_data); - auto b = builder.ConstantR1(b_data); - builder.Mul(a, b); + auto a = ConstantR1(&builder, a_data); + auto b = ConstantR1(&builder, b_data); + Mul(a, b); ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}}); - auto b = builder.ConstantR1( - {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}}); - builder.Mul(a, b); + auto a = ConstantR1( + &builder, {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}}); + auto b = ConstantR1(&builder, + {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}}); + Mul(a, b); ComputeAndCompareR1( &builder, {{0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0}}, {}, @@ -701,27 +702,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) { XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Mul(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Mul(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({false, false, true, true}); - auto b = builder.ConstantR1({false, true, false, true}); - builder.And(a, b); + auto a = ConstantR1(&builder, {false, false, true, true}); + auto b = ConstantR1(&builder, {false, true, false, true}); + And(a, b); ComputeAndCompareR1(&builder, {false, false, false, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{false, false}, {true, true}}); - auto b = builder.ConstantR2({{false, true}, {false, true}}); - builder.And(a, b); + auto a = ConstantR2(&builder, {{false, false}, {true, true}}); + auto b = ConstantR2(&builder, {{false, true}, {false, true}}); + And(a, b); Array2D expected_array({{false, false}, {false, true}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -729,27 +730,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) { XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.And(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({0, -1, -8}); - auto b = builder.ConstantR1({5, -7, 12}); - builder.And(a, b); + auto a = ConstantR1(&builder, {0, -1, -8}); + auto b = ConstantR1(&builder, {5, -7, 12}); + And(a, b); ComputeAndCompareR1(&builder, {0, -7, 8}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{0, -5}, {-1, 5}}); - auto b = builder.ConstantR2({{1, -6}, {4, 5}}); - builder.And(a, b); + auto a = ConstantR2(&builder, {{0, -5}, {-1, 5}}); + auto b = ConstantR2(&builder, {{1, -6}, {4, 5}}); + And(a, b); Array2D expected_array({{0, -6}, {4, 5}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -757,27 +758,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) { XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.And(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({0, 1, 8}); - auto b = builder.ConstantR1({5, 7, 12}); - builder.And(a, b); + auto a = ConstantR1(&builder, {0, 1, 8}); + auto b = ConstantR1(&builder, {5, 7, 12}); + And(a, b); ComputeAndCompareR1(&builder, {0, 1, 8}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{0, 1}, {3, 8}}); - auto b = builder.ConstantR2({{1, 0}, {7, 6}}); - builder.And(a, b); + auto a = ConstantR2(&builder, {{0, 1}, {3, 8}}); + auto b = ConstantR2(&builder, {{1, 0}, {7, 6}}); + And(a, b); Array2D expected_array({{0, 0}, {3, 0}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -785,27 +786,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) { XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.And(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({false, false, true, true}); - auto b = builder.ConstantR1({false, true, false, true}); - builder.Or(a, b); + auto a = ConstantR1(&builder, {false, false, true, true}); + auto b = ConstantR1(&builder, {false, true, false, true}); + Or(a, b); ComputeAndCompareR1(&builder, {false, true, true, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{false, false}, {true, true}}); - auto b = builder.ConstantR2({{false, true}, {false, true}}); - builder.Or(a, b); + auto a = ConstantR2(&builder, {{false, false}, {true, true}}); + auto b = ConstantR2(&builder, {{false, true}, {false, true}}); + Or(a, b); Array2D expected_array({{false, true}, {true, true}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -813,27 +814,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) { XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Or(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({0, -1, 8}); - auto b = builder.ConstantR1({5, -7, 4}); - builder.Or(a, b); + auto a = ConstantR1(&builder, {0, -1, 8}); + auto b = ConstantR1(&builder, {5, -7, 4}); + Or(a, b); ComputeAndCompareR1(&builder, {5, -1, 12}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{0, -1}, {8, 8}}); - auto b = builder.ConstantR2({{5, -7}, {4, 1}}); - builder.Or(a, b); + auto a = ConstantR2(&builder, {{0, -1}, {8, 8}}); + auto b = ConstantR2(&builder, {{5, -7}, {4, 1}}); + Or(a, b); Array2D expected_array({{5, -1}, {12, 9}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -841,27 +842,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) { XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Or(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({0, 1, 8}); - auto b = builder.ConstantR1({5, 7, 4}); - builder.Or(a, b); + auto a = ConstantR1(&builder, {0, 1, 8}); + auto b = ConstantR1(&builder, {5, 7, 4}); + Or(a, b); ComputeAndCompareR1(&builder, {5, 7, 12}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{0, 1}, {8, 8}}); - auto b = builder.ConstantR2({{5, 7}, {4, 1}}); - builder.Or(a, b); + auto a = ConstantR2(&builder, {{0, 1}, {8, 8}}); + auto b = ConstantR2(&builder, {{5, 7}, {4, 1}}); + Or(a, b); Array2D expected_array({{5, 7}, {12, 9}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -869,25 +870,108 @@ XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) { XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Or(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, XorPredR1) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {false, false, true, true}); + auto b = ConstantR1(&builder, {false, true, false, true}); + Xor(a, b); + + ComputeAndCompareR1(&builder, {false, true, true, false}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorPredR2) { + XlaBuilder builder(TestName()); + auto a = ConstantR2(&builder, {{false, false}, {true, true}}); + auto b = ConstantR2(&builder, {{false, true}, {false, true}}); + Xor(a, b); + + Array2D expected_array({{false, true}, {true, false}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementPredR1) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Xor(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorS32R1) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {0, -1, 8}); + auto b = ConstantR1(&builder, {5, -7, 4}); + Xor(a, b); + + ComputeAndCompareR1(&builder, {5, 6, 12}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorS32R2) { + XlaBuilder builder(TestName()); + auto a = ConstantR2(&builder, {{0, -1}, {8, 8}}); + auto b = ConstantR2(&builder, {{5, -7}, {4, 1}}); + Xor(a, b); + + Array2D expected_array({{5, 6}, {12, 9}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementS32R1) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Xor(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorU32R1) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {0, 1, 8}); + auto b = ConstantR1(&builder, {5, 7, 4}); + Xor(a, b); + + ComputeAndCompareR1(&builder, {5, 6, 12}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorU32R2) { + XlaBuilder builder(TestName()); + auto a = ConstantR2(&builder, {{0, 1}, {8, 8}}); + auto b = ConstantR2(&builder, {{5, 7}, {4, 1}}); + Xor(a, b); + + Array2D expected_array({{5, 6}, {12, 9}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementU32R1) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Xor(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({false, true, true, false}); - builder.Not(a); + auto a = ConstantR1(&builder, {false, true, true, false}); + Not(a); ComputeAndCompareR1(&builder, {true, false, false, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{false, true}, {true, false}}); - builder.Not(a); + auto a = ConstantR2(&builder, {{false, true}, {true, false}}); + Not(a); Array2D expected_array({{true, false}, {false, true}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -895,24 +979,24 @@ XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) { XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.Not(a); + auto a = ConstantR1(&builder, {}); + Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-1, 0, 1}); - builder.Not(a); + auto a = ConstantR1(&builder, {-1, 0, 1}); + Not(a); ComputeAndCompareR1(&builder, {0, -1, -2}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{-1, 0}, {1, 8}}); - builder.Not(a); + auto a = ConstantR2(&builder, {{-1, 0}, {1, 8}}); + Not(a); Array2D expected_array({{0, -1}, {-2, -9}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -920,24 +1004,24 @@ XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) { XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.Not(a); + auto a = ConstantR1(&builder, {}); + Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({0, 4294967295}); - builder.Not(a); + auto a = ConstantR1(&builder, {0, 4294967295}); + Not(a); ComputeAndCompareR1(&builder, {4294967295, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{0, 4294967295}, {1, 4294967294}}); - builder.Not(a); + auto a = ConstantR2(&builder, {{0, 4294967295}, {1, 4294967294}}); + Not(a); Array2D expected_array({{4294967295, 0}, {4294967294, 1}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -945,19 +1029,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) { XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.Not(a); + auto a = ConstantR1(&builder, {}); + Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({static_cast(0x12345678), - static_cast(0xF0001000), 1, 3, 77, - 1, -3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 15, 32, 100, -1}); - builder.ShiftLeft(a, b); + auto a = ConstantR1( + &builder, {static_cast(0x12345678), static_cast(0xF0001000), + 1, 3, 77, 1, -3, 77}); + auto b = ConstantR1(&builder, {4, 8, 2, 7, 15, 32, 100, -1}); + ShiftLeft(a, b); ComputeAndCompareR1(&builder, {static_cast(0x23456780), 0x00100000, 0x4, @@ -967,11 +1051,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) { XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({static_cast(0x92345678), - static_cast(0x10001000), 1, 3, 77, - 1, -3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 2, 32, 100, -1}); - builder.ShiftRightArithmetic(a, b); + auto a = ConstantR1( + &builder, {static_cast(0x92345678), static_cast(0x10001000), + 1, 3, 77, 1, -3, 77}); + auto b = ConstantR1(&builder, {4, 8, 2, 7, 2, 32, 100, -1}); + ShiftRightArithmetic(a, b); ComputeAndCompareR1( &builder, @@ -982,11 +1066,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) { XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({static_cast(0x92345678), - static_cast(0x10001000), 1, 3, 77, - 1, -3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 5, 32, 100, -1}); - builder.ShiftRightLogical(a, b); + auto a = ConstantR1( + &builder, {static_cast(0x92345678), static_cast(0x10001000), + 1, 3, 77, 1, -3, 77}); + auto b = ConstantR1(&builder, {4, 8, 2, 7, 5, 32, 100, -1}); + ShiftRightLogical(a, b); ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {}); @@ -994,10 +1078,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) { XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {0x12345678, 0xF0001000, 1, 3, 77, 1, ~3u, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 15, 32, 100, ~0u}); - builder.ShiftLeft(a, b); + auto a = ConstantR1(&builder, + {0x12345678, 0xF0001000, 1, 3, 77, 1, ~3u, 77}); + auto b = ConstantR1(&builder, {4, 8, 2, 7, 15, 32, 100, ~0u}); + ShiftLeft(a, b); ComputeAndCompareR1( &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136, 0, 0, 0}, {}); @@ -1005,10 +1089,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) { XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 2, 32, 100, ~0u}); - builder.ShiftRightArithmetic(a, b); + auto a = ConstantR1(&builder, + {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77}); + auto b = ConstantR1(&builder, {4, 8, 2, 7, 2, 32, 100, ~0u}); + ShiftRightArithmetic(a, b); ComputeAndCompareR1( &builder, {0xF9234567, 0x00100010, 0, 0, 19, 0, ~0u, 0}, {}); @@ -1016,10 +1100,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) { XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 5, 32, 100, ~0u}); - builder.ShiftRightLogical(a, b); + auto a = ConstantR1(&builder, + {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77}); + auto b = ConstantR1(&builder, {4, 8, 2, 7, 5, 32, 100, ~0u}); + ShiftRightLogical(a, b); ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {}); @@ -1028,18 +1112,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({10.0f, 5.0f, 2.25f, 10.0f, NAN}); - builder.Eq(lhs, rhs); + auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 2.25f, 10.0f, NAN}); + Eq(lhs, rhs); ComputeAndCompareR1(&builder, {false, false, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - builder.Eq(lhs, rhs); + auto lhs = ConstantR1(&builder, {}); + auto rhs = ConstantR1(&builder, {}); + Eq(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}); } @@ -1047,9 +1131,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - builder.Ge(lhs, rhs); + auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN}); + Ge(lhs, rhs); ComputeAndCompareR1(&builder, {false, true, true, false, false}, {}); } @@ -1057,9 +1141,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - builder.Gt(lhs, rhs); + auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN}); + Gt(lhs, rhs); ComputeAndCompareR1(&builder, {false, true, true, false, false}, {}); } @@ -1067,9 +1151,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({-2.5f, 5.0f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - builder.Le(lhs, rhs); + auto lhs = ConstantR1(&builder, {-2.5f, 5.0f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN}); + Le(lhs, rhs); ComputeAndCompareR1(&builder, {true, true, false, false, false}, {}); } @@ -1077,9 +1161,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - builder.Lt(lhs, rhs); + auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN}); + Lt(lhs, rhs); ComputeAndCompareR1(&builder, {true, false, false, false, false}, {}); } @@ -1088,9 +1172,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); - auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - builder.Eq(lhs, rhs); + auto lhs = + ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max}); + auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max}); + Eq(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, false, true, false, false, false, true}, @@ -1099,9 +1184,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - builder.Eq(lhs, rhs); + auto lhs = ConstantR1(&builder, {}); + auto rhs = ConstantR1(&builder, {}); + Eq(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}); } @@ -1109,26 +1194,26 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({{-2.5f, 10.0f}, - {1.0f, 25.5f}, - {2.25f, -3.0f}, - {NAN, 0.0f}, - {1.0f, 6.0f}}); - auto rhs = builder.ConstantR1({{0.0f, 10.0f}, - {1.0f, 5.0f}, - {2.25f, -3.0f}, - {10.0f, 0.0f}, - {1.0f, NAN}}); - builder.Eq(lhs, rhs); + auto lhs = ConstantR1(&builder, {{-2.5f, 10.0f}, + {1.0f, 25.5f}, + {2.25f, -3.0f}, + {NAN, 0.0f}, + {1.0f, 6.0f}}); + auto rhs = ConstantR1(&builder, {{0.0f, 10.0f}, + {1.0f, 5.0f}, + {2.25f, -3.0f}, + {10.0f, 0.0f}, + {1.0f, NAN}}); + Eq(lhs, rhs); ComputeAndCompareR1(&builder, {false, false, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - builder.Eq(lhs, rhs); + auto lhs = ConstantR1(&builder, {}); + auto rhs = ConstantR1(&builder, {}); + Eq(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}); } @@ -1138,17 +1223,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({{-2.5f, 10.0f}, - {1.0f, 25.5f}, - {2.25f, -3.0f}, - {NAN, 0.0f}, - {1.0f, 6.0f}}); - auto rhs = builder.ConstantR1({{0.0f, 10.0f}, - {1.0f, 5.0f}, - {2.25f, -3.0f}, - {10.0f, 0.0f}, - {1.0f, NAN}}); - builder.Ne(lhs, rhs); + auto lhs = ConstantR1(&builder, {{-2.5f, 10.0f}, + {1.0f, 25.5f}, + {2.25f, -3.0f}, + {NAN, 0.0f}, + {1.0f, 6.0f}}); + auto rhs = ConstantR1(&builder, {{0.0f, 10.0f}, + {1.0f, 5.0f}, + {2.25f, -3.0f}, + {10.0f, 0.0f}, + {1.0f, NAN}}); + Ne(lhs, rhs); ComputeAndCompareR1(&builder, {true, true, false, true, true}, {}); } @@ -1158,9 +1243,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({10.0f, 25.5f, 1.0f, 10.0f, NAN}); - builder.Ne(lhs, rhs); + auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {10.0f, 25.5f, 1.0f, 10.0f, NAN}); + Ne(lhs, rhs); ComputeAndCompareR1(&builder, {true, false, true, true, true}, {}); } @@ -1169,9 +1254,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); - auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - builder.Ne(lhs, rhs); + auto lhs = + ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max}); + auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max}); + Ne(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, true, false, true, true, true, false}, {}); @@ -1181,9 +1267,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); - auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - builder.Ge(lhs, rhs); + auto lhs = + ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max}); + auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max}); + Ge(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, true, true, false, true, true, true}, {}); @@ -1193,9 +1280,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); - auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - builder.Gt(lhs, rhs); + auto lhs = + ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max}); + auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max}); + Gt(lhs, rhs); ComputeAndCompareR1( &builder, {false, false, false, true, false, false, true, true, false}, @@ -1206,9 +1294,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); - auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - builder.Le(lhs, rhs); + auto lhs = + ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max}); + auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max}); + Le(lhs, rhs); ComputeAndCompareR1( &builder, {true, true, true, false, true, true, false, false, true}, {}); @@ -1218,9 +1307,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); - auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - builder.Lt(lhs, rhs); + auto lhs = + ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max}); + auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max}); + Lt(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, false, false, true, false, false, false}, @@ -1230,9 +1320,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); - auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - builder.Eq(lhs, rhs); + auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); + Eq(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, false, true, false, false, false, true}, @@ -1242,9 +1332,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); - auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - builder.Ne(lhs, rhs); + auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); + Ne(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, true, false, true, true, true, false}, {}); @@ -1253,9 +1343,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); - auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - builder.Ge(lhs, rhs); + auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); + Ge(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, true, true, false, true, true, true}, {}); @@ -1264,9 +1354,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); - auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - builder.Gt(lhs, rhs); + auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); + Gt(lhs, rhs); ComputeAndCompareR1( &builder, {false, false, false, true, false, false, true, true, false}, @@ -1276,9 +1366,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); - auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - builder.Le(lhs, rhs); + auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); + Le(lhs, rhs); ComputeAndCompareR1( &builder, {true, true, true, false, true, true, false, false, true}, {}); @@ -1287,9 +1377,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); - auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - builder.Lt(lhs, rhs); + auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); + Lt(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, false, false, true, false, false, false}, @@ -1300,10 +1390,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); auto lhs = - builder.ConstantR1({4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f}); + ConstantR1(&builder, {4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f}); auto rhs = - builder.ConstantR1({2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f}); - builder.Pow(lhs, rhs); + ConstantR1(&builder, {2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f}); + Pow(lhs, rhs); ComputeAndCompareR1( &builder, {16.0f, 0.25f, 8.0f, NAN, NAN, -8.0f, 16.0f}, {}, error_spec_); @@ -1312,9 +1402,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) { XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({-2.0f, -0.6f, -0.6f, 0.0f}); - auto rhs = builder.ConstantR1({0.5f, 0.6f, -0.6f, -0.6f}); - builder.Pow(lhs, rhs); + auto lhs = ConstantR1(&builder, {-2.0f, -0.6f, -0.6f, 0.0f}); + auto rhs = ConstantR1(&builder, {0.5f, 0.6f, -0.6f, -0.6f}); + Pow(lhs, rhs); ComputeAndCompareR1(&builder, {NAN, NAN, NAN, INFINITY}, {}, error_spec_); @@ -1322,9 +1412,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) { XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - builder.Pow(lhs, rhs); + auto lhs = ConstantR1(&builder, {}); + auto rhs = ConstantR1(&builder, {}); + Pow(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -1340,10 +1430,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { std::unique_ptr param_data = client_->TransferToServer(*param_literal).ConsumeValueOrDie(); - auto sum = b.ConstantR0(0.0f); - auto param = b.Parameter(0, param_literal->shape(), "param"); + auto sum = ConstantR0(&b, 0.0f); + auto param = Parameter(&b, 0, param_literal->shape(), "param"); for (float exponent : exponents) { - sum = b.Add(sum, b.Pow(param, b.ConstantR0(exponent))); + sum = Add(sum, Pow(param, ConstantR0(&b, exponent))); } std::vector expected; @@ -1370,9 +1460,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) { std::unique_ptr literal1 = Literal::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - b.Pow(b.Exp(param0), param1); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + Pow(Exp(param0), param1); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1395,9 +1485,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) { std::unique_ptr literal1 = Literal::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - b.Log(b.Pow(param0, param1)); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + Log(Pow(param0, param1)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1420,9 +1510,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) { std::unique_ptr literal1 = Literal::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - b.Mul(b.Exp(param0), b.Exp(param1)); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + Mul(Exp(param0), Exp(param1)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1445,9 +1535,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) { std::unique_ptr literal1 = Literal::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - b.Div(param0, b.Exp(param1)); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + Div(param0, Exp(param1)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1476,10 +1566,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) { std::unique_ptr literal2 = Literal::CreateR1(values2); std::unique_ptr data2 = client_->TransferToServer(*literal2).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - auto param2 = b.Parameter(2, literal2->shape(), "param2"); - b.Div(b.Div(param0, param1), param2); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + Div(Div(param0, param1), param2); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1509,10 +1599,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) { std::unique_ptr data2 = client_->TransferToServer(*literal2).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - auto param2 = b.Parameter(2, literal2->shape(), "param2"); - b.Div(param0, b.Div(param1, param2)); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + Div(param0, Div(param1, param2)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1542,10 +1632,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) { std::unique_ptr data2 = client_->TransferToServer(*literal2).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - auto param2 = b.Parameter(2, literal2->shape(), "param2"); - b.Div(param0, b.Pow(param1, param2)); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + Div(param0, Pow(param1, param2)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1580,11 +1670,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) { std::unique_ptr data3 = client_->TransferToServer(*literal3).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - auto param2 = b.Parameter(2, literal2->shape(), "param2"); - auto param3 = b.Parameter(3, literal3->shape(), "param2"); - b.Div(b.Div(param0, param1), b.Div(param2, param3)); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + auto param3 = Parameter(&b, 3, literal3->shape(), "param2"); + Div(Div(param0, param1), Div(param2, param3)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1604,8 +1694,8 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) { for (int i = 0; i < count; ++i) { values.push_back(i / static_cast(count)); } - auto x = builder.ConstantR1(values); - builder.Pow(x, builder.ConstantR0(2.0f)); + auto x = ConstantR1(&builder, values); + Pow(x, ConstantR0(&builder, 2.0f)); std::vector expected; expected.reserve(values.size()); @@ -1630,8 +1720,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) { Array4D expected(2, 2, 2, 2, expected_vector); - auto x = builder.ConstantR4FromArray4D(values); - builder.Pow(x, builder.ConstantR0(2.0f)); + auto x = ConstantR4FromArray4D(&builder, values); + Pow(x, ConstantR0(&builder, 2.0f)); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } @@ -1641,8 +1731,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) { Array4D values(2, 2, 0, 2); Array4D expected(2, 2, 0, 2); - auto x = builder.ConstantR4FromArray4D(values); - builder.Pow(x, builder.ConstantR0(2.0f)); + auto x = ConstantR4FromArray4D(&builder, values); + Pow(x, ConstantR0(&builder, 2.0f)); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } @@ -1650,9 +1740,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) { XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) { XlaBuilder builder(TestName()); SetFastMathDisabled(true); - auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f, 10.0f, NAN}); - builder.Min(lhs, rhs); + auto lhs = ConstantR1(&builder, {1.0f, 1.0f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {2.0f, -5.0f, 1.0f, 10.0f, NAN}); + Min(lhs, rhs); ComputeAndCompareR1(&builder, {1.0f, -5.0f, 1.0f, NAN, NAN}, {}, error_spec_); @@ -1660,18 +1750,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) { XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - builder.Min(lhs, rhs); + auto lhs = ConstantR1(&builder, {}); + auto rhs = ConstantR1(&builder, {}); + Min(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) { XlaBuilder builder(TestName()); SetFastMathDisabled(true); - auto lhs = builder.ConstantR1({1.0, 1.0, 2.25, NAN, 6.0}); - auto rhs = builder.ConstantR1({2.0, -5.0, 1.0, 10.0, NAN}); - builder.Min(lhs, rhs); + auto lhs = ConstantR1(&builder, {1.0, 1.0, 2.25, NAN, 6.0}); + auto rhs = ConstantR1(&builder, {2.0, -5.0, 1.0, 10.0, NAN}); + Min(lhs, rhs); ComputeAndCompareR1(&builder, {1.0, -5.0, 1.0, NAN, NAN}, {}, error_spec_); @@ -1680,9 +1770,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) { XlaBuilder builder(TestName()); SetFastMathDisabled(true); - auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f, 10.0f, NAN}); - builder.Max(lhs, rhs); + auto lhs = ConstantR1(&builder, {1.0f, 1.0f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {2.0f, -5.0f, 1.0f, 10.0f, NAN}); + Max(lhs, rhs); ComputeAndCompareR1(&builder, {2.0f, 1.0f, 2.25f, NAN, NAN}, {}, error_spec_); @@ -1690,18 +1780,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - builder.Max(lhs, rhs); + auto lhs = ConstantR1(&builder, {}); + auto rhs = ConstantR1(&builder, {}); + Max(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) { XlaBuilder builder(TestName()); SetFastMathDisabled(true); - auto lhs = builder.ConstantR1({1.0, 1.0, 2.25, NAN, 6.0}); - auto rhs = builder.ConstantR1({2.0, -5.0, 1.0, 10.0, NAN}); - builder.Max(lhs, rhs); + auto lhs = ConstantR1(&builder, {1.0, 1.0, 2.25, NAN, 6.0}); + auto rhs = ConstantR1(&builder, {2.0, -5.0, 1.0, 10.0, NAN}); + Max(lhs, rhs); ComputeAndCompareR1(&builder, {2.0, 1.0, 2.25, NAN, NAN}, {}, error_spec_); @@ -1711,11 +1801,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); - auto y = builder.ConstantR1( - {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min}); - builder.Max(x, y); + auto x = ConstantR1( + &builder, {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); + auto y = ConstantR1( + &builder, {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min}); + Max(x, y); std::vector expected = {min, max, 0, -1, 0, 0, 0, 1, 1, 10, max, max, max}; @@ -1726,11 +1816,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); - auto y = builder.ConstantR1( - {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min}); - builder.Min(x, y); + auto x = ConstantR1( + &builder, {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); + auto y = ConstantR1( + &builder, {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min}); + Min(x, y); std::vector expected = {min, min, min, -10, -1, -1, 0, 0, 0, 1, 0, max, min}; @@ -1740,9 +1830,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({0, 0, 1, 1, 1, max, max, max}); - auto y = builder.ConstantR1({0, 1, 0, 1, 10, 0, 234234, max}); - builder.Max(x, y); + auto x = ConstantR1(&builder, {0, 0, 1, 1, 1, max, max, max}); + auto y = ConstantR1(&builder, {0, 1, 0, 1, 10, 0, 234234, max}); + Max(x, y); std::vector expected = {0, 1, 1, 1, 10, max, max, max}; ComputeAndCompareR1(&builder, expected, {}); @@ -1751,9 +1841,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) { XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({0, 0, 1, 1, 1, max, max, max}); - auto y = builder.ConstantR1({0, 1, 0, 1, 10, 0, 234234, max}); - builder.Min(x, y); + auto x = ConstantR1(&builder, {0, 0, 1, 1, 1, max, max, max}); + auto y = ConstantR1(&builder, {0, 1, 0, 1, 10, 0, 234234, max}); + Min(x, y); std::vector expected = {0, 0, 0, 1, 1, 0, 234234, max}; ComputeAndCompareR1(&builder, expected, {}); @@ -1761,11 +1851,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); - auto y = builder.ConstantR1( - {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0}); - builder.Max(x, y); + auto x = ConstantR1( + &builder, {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); + auto y = ConstantR1( + &builder, {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0}); + Max(x, y); std::vector expected = {-0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; @@ -1774,9 +1864,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) { XlaBuilder builder(TestName()); - auto u = builder.ConstantR1({3.5}); - auto v = builder.ConstantR1({}); - builder.Max(u, v); + auto u = ConstantR1(&builder, {3.5}); + auto v = ConstantR1(&builder, {}); + Max(u, v); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -1784,9 +1874,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) { for (int broadcast_dim : {0, 1}) { XlaBuilder builder(TestName()); - auto u = builder.ConstantR1({3.5}); - auto v = builder.ConstantR2FromArray2D(Array2D(0, 2)); - builder.Max(u, v, /*broadcast_dimensions=*/{broadcast_dim}); + auto u = ConstantR1(&builder, {3.5}); + auto v = ConstantR2FromArray2D(&builder, Array2D(0, 2)); + Max(u, v, /*broadcast_dimensions=*/{broadcast_dim}); ComputeAndCompareR2(&builder, Array2D(0, 2), {}, error_spec_); } @@ -1794,10 +1884,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) { XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) { XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({2.0f, 3.0f, 4.0f}); - auto m = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - builder.Max(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); + auto m = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + Max(v, m, /*broadcast_dimensions=*/{1}); Array2D expected({{2.0f, 3.14f, 4.0f}, {2.25f, 3.0f, 4.0f}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); @@ -1805,9 +1895,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) { XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({}); - auto m = builder.ConstantR2({{}, {}}); - builder.Max(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {}); + auto m = ConstantR2(&builder, {{}, {}}); + Max(v, m, /*broadcast_dimensions=*/{1}); Array2D expected({{}, {}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); @@ -1815,10 +1905,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) { XlaBuilder builder(TestName()); - auto scalar = builder.ConstantR0(2); + auto scalar = ConstantR0(&builder, 2); Array3D a_3d({{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}}); - auto array = builder.ConstantR3FromArray3D(a_3d); - builder.Max(array, scalar, /*broadcast_dimensions=*/{}); + auto array = ConstantR3FromArray3D(&builder, a_3d); + Max(array, scalar, /*broadcast_dimensions=*/{}); Array3D expected({{{3, 9, 2}, {2, 2, 3}}, {{2, 2, 8}, {12, 10, 4}}}); ComputeAndCompareR3(&builder, expected, {}); @@ -1826,10 +1916,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) { XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) { XlaBuilder builder(TestName()); - auto scalar = builder.ConstantR0(2); + auto scalar = ConstantR0(&builder, 2); Array3D a_3d(2, 0, 3); - auto array = builder.ConstantR3FromArray3D(a_3d); - builder.Max(array, scalar, /*broadcast_dimensions=*/{}); + auto array = ConstantR3FromArray3D(&builder, a_3d); + Max(array, scalar, /*broadcast_dimensions=*/{}); Array3D expected(2, 0, 3); ComputeAndCompareR3(&builder, expected, {}); @@ -1837,10 +1927,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) { XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) { XlaBuilder builder(TestName()); - auto m = - builder.ConstantR2({{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}}); - auto v = builder.ConstantR1({-10.2f, 16.4f}); - builder.Min(m, v, /*broadcast_dimensions=*/{0}); + auto m = ConstantR2(&builder, + {{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}}); + auto v = ConstantR1(&builder, {-10.2f, 16.4f}); + Min(m, v, /*broadcast_dimensions=*/{0}); Array2D expected({{-10.4f, -10.2f, -10.2f}, {0.1f, 16.4f, 16.1f}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); @@ -1848,9 +1938,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) { XlaBuilder builder(TestName()); - auto m = builder.ConstantR2({{}, {}}); - auto v = builder.ConstantR1({-10.2f, 16.4f}); - builder.Min(m, v, /*broadcast_dimensions=*/{0}); + auto m = ConstantR2(&builder, {{}, {}}); + auto v = ConstantR1(&builder, {-10.2f, 16.4f}); + Min(m, v, /*broadcast_dimensions=*/{0}); Array2D expected({{}, {}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); @@ -1859,11 +1949,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) { XlaBuilder builder(TestName()); auto array2d = - builder.ConstantR2({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); - auto array4d = builder.ConstantR4FromArray4D( - {{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}}, - {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}}); - builder.Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3}); + ConstantR2(&builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); + auto array4d = ConstantR4FromArray4D( + &builder, {{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}}, + {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}}); + Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3}); Array4D expected( {{{{-12.2f, 32.3f, 6.1f}}, {{0.0f, 32.2f, 2.5f}}}, @@ -1874,10 +1964,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) { XlaBuilder builder(TestName()); auto array2d = - builder.ConstantR2({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); + ConstantR2(&builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); Array4D arg(2, 2, 0, 3); - auto array4d = builder.ConstantR4FromArray4D(arg); - builder.Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3}); + auto array4d = ConstantR4FromArray4D(&builder, arg); + Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3}); Array4D expected(2, 2, 0, 3); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -1885,9 +1975,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) { XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto y = builder.ConstantR1({9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); - builder.Min(x, y); + auto x = ConstantR1(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = ConstantR1(&builder, {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); + Min(x, y); std::vector expected = {0, 1, 2, 3, 4, 4, 3, 2, 1, 0}; ComputeAndCompareR1(&builder, expected, {}); @@ -1895,9 +1985,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto y = builder.ConstantR1({9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); - builder.Max(x, y); + auto x = ConstantR1(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = ConstantR1(&builder, {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); + Max(x, y); std::vector expected = {9, 8, 7, 6, 5, 5, 6, 7, 8, 9}; ComputeAndCompareR1(&builder, expected, {}); @@ -1905,19 +1995,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) { XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-3, 26, 2, -1, 1}); - auto b = builder.ConstantR1({10, 5, 1, 10, -10}); - builder.Rem(a, b); + auto a = ConstantR1(&builder, {-3, 26, 2, -1, 1}); + auto b = ConstantR1(&builder, {10, 5, 1, 10, -10}); + Rem(a, b); ComputeAndCompareR1(&builder, {-3, 1, 0, -1, 1}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) { XlaBuilder builder(TestName()); - auto minimum = builder.ConstantR1({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); - auto argument = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 10.0f}); - auto maximum = builder.ConstantR1({3.0f, 0.5f, 25.5f, 5.0f, 123.0}); - builder.Clamp(minimum, argument, maximum); + auto minimum = ConstantR1(&builder, {1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); + auto argument = + ConstantR1(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 10.0f}); + auto maximum = ConstantR1(&builder, {3.0f, 0.5f, 25.5f, 5.0f, 123.0}); + Clamp(minimum, argument, maximum); ComputeAndCompareR1(&builder, {2.0f, 0.5f, 1.0f, 2.25f, 10.0f}, {}, error_spec_); @@ -1925,10 +2016,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) { XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) { XlaBuilder builder(TestName()); - auto minimum = builder.ConstantR0(0.0f); - auto argument = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); - auto maximum = builder.ConstantR0(5.0f); - builder.Clamp(minimum, argument, maximum); + auto minimum = ConstantR0(&builder, 0.0f); + auto argument = ConstantR1(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); + auto maximum = ConstantR0(&builder, 5.0f); + Clamp(minimum, argument, maximum); ComputeAndCompareR1(&builder, {2.0f, 5.0f, 0.0f, 1.0f, 4.0f}, {}, error_spec_); @@ -1936,16 +2027,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) { XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) { XlaBuilder builder(TestName()); - auto min_scalar = builder.ConstantR0(0.0f); - auto min_vector = builder.ConstantR1({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); - auto arg_vector = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); - auto max_scalar = builder.ConstantR0(3.0f); - auto max_vector = builder.ConstantR1({3.0f, 0.5f, 25.5f, 5.0f, 123.0}); + auto min_scalar = ConstantR0(&builder, 0.0f); + auto min_vector = + ConstantR1(&builder, {1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); + auto arg_vector = + ConstantR1(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); + auto max_scalar = ConstantR0(&builder, 3.0f); + auto max_vector = + ConstantR1(&builder, {3.0f, 0.5f, 25.5f, 5.0f, 123.0}); // Perform clamp with broadcasted scalar and vector. - builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), - builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), - builder.Clamp(min_scalar, arg_vector, max_scalar))); + Add(Add(Clamp(min_vector, arg_vector, max_scalar), + Clamp(min_scalar, arg_vector, max_vector)), + Add(Clamp(min_vector, arg_vector, max_vector), + Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8.0f, 7.0f, 2.0f, 6.5f, 14.0f}, {}, error_spec_); @@ -1953,52 +2047,52 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) { XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) { XlaBuilder builder(TestName()); - auto min_vector = builder.ConstantR1({1, -6, 1, 2, 0, -5}); - auto arg_vector = builder.ConstantR1({2, 10, -5, 1, 4, 10}); - auto max_vector = builder.ConstantR1({3, 0, 25, 5, 123, -1}); - builder.Clamp(min_vector, arg_vector, max_vector); + auto min_vector = ConstantR1(&builder, {1, -6, 1, 2, 0, -5}); + auto arg_vector = ConstantR1(&builder, {2, 10, -5, 1, 4, 10}); + auto max_vector = ConstantR1(&builder, {3, 0, 25, 5, 123, -1}); + Clamp(min_vector, arg_vector, max_vector); ComputeAndCompareR1(&builder, {2, 0, 1, 2, 4, -1}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampS32ScalarVector) { XlaBuilder builder(TestName()); - auto min_scalar = builder.ConstantR0(0); - auto min_vector = builder.ConstantR1({1, -6, 1, 2, 0}); - auto arg_vector = builder.ConstantR1({2, 10, -5, 1, 4}); - auto max_scalar = builder.ConstantR0(3); - auto max_vector = builder.ConstantR1({3, 1, 25, 5, 123}); + auto min_scalar = ConstantR0(&builder, 0); + auto min_vector = ConstantR1(&builder, {1, -6, 1, 2, 0}); + auto arg_vector = ConstantR1(&builder, {2, 10, -5, 1, 4}); + auto max_scalar = ConstantR0(&builder, 3); + auto max_vector = ConstantR1(&builder, {3, 1, 25, 5, 123}); // Perform clamp with broadcasted scalar and vector. - builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), - builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), - builder.Clamp(min_scalar, arg_vector, max_scalar))); + Add(Add(Clamp(min_vector, arg_vector, max_scalar), + Clamp(min_scalar, arg_vector, max_vector)), + Add(Clamp(min_vector, arg_vector, max_vector), + Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8, 8, 2, 6, 14}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) { XlaBuilder builder(TestName()); - auto min_vector = builder.ConstantR1({1, 2, 1, 2, 0, ~0u - 4}); - auto arg_vector = builder.ConstantR1({2, 10, 5, 1, 4, 10}); - auto max_vector = builder.ConstantR1({3, 5, 25, 5, 123, ~0u}); - builder.Clamp(min_vector, arg_vector, max_vector); + auto min_vector = ConstantR1(&builder, {1, 2, 1, 2, 0, ~0u - 4}); + auto arg_vector = ConstantR1(&builder, {2, 10, 5, 1, 4, 10}); + auto max_vector = ConstantR1(&builder, {3, 5, 25, 5, 123, ~0u}); + Clamp(min_vector, arg_vector, max_vector); ComputeAndCompareR1(&builder, {2, 5, 5, 2, 4, ~0u - 4}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) { XlaBuilder builder(TestName()); - auto min_scalar = builder.ConstantR0(0); - auto min_vector = builder.ConstantR1({1, 0, 1, 2, 0}); - auto arg_vector = builder.ConstantR1({2, 10, 0, 1, 4}); - auto max_scalar = builder.ConstantR0(3); - auto max_vector = builder.ConstantR1({3, 1, 25, 5, 123}); + auto min_scalar = ConstantR0(&builder, 0); + auto min_vector = ConstantR1(&builder, {1, 0, 1, 2, 0}); + auto arg_vector = ConstantR1(&builder, {2, 10, 0, 1, 4}); + auto max_scalar = ConstantR0(&builder, 3); + auto max_vector = ConstantR1(&builder, {3, 1, 25, 5, 123}); // Perform clamp with broadcasted scalar and vector. - builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), - builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), - builder.Clamp(min_scalar, arg_vector, max_scalar))); + Add(Add(Clamp(min_vector, arg_vector, max_scalar), + Clamp(min_scalar, arg_vector, max_vector)), + Add(Clamp(min_vector, arg_vector, max_vector), + Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8, 8, 2, 6, 14}, {}); } @@ -2016,9 +2110,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Add(p0, p1); + auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Add(p0, p1); ComputeAndCompareR1(&builder, {8.3f, 4.5f, 6.7f, 11.1f}, {param0_data.get(), param1_data.get()}, @@ -2038,9 +2132,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Add(p0, p1); + auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Add(p0, p1); Array3D expected(0, 7, 0); ComputeAndCompareR3( @@ -2055,9 +2149,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto a = builder.ConstantR1({1.1f, 2.2f, 3.3f, 4.4f}); - auto p = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Add(a, p); + auto a = ConstantR1(&builder, {1.1f, 2.2f, 3.3f, 4.4f}); + auto p = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Add(a, p); ComputeAndCompareR1(&builder, {2.2f, 4.4f, 6.6f, 9.9f}, {param0_data.get()}, error_spec_); @@ -2065,8 +2159,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({3.14159f, 0.0f, 1.570796f, -0.78539f}); - builder.Cos(a); + auto a = ConstantR1(&builder, {3.14159f, 0.0f, 1.570796f, -0.78539f}); + Cos(a); ComputeAndCompareR1(&builder, {-1.0f, 1.0f, 0.0f, 0.707107f}, {}, error_spec_); @@ -2074,8 +2168,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) { XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({3.14159f, 0.0f, 1.570796f, -0.78539f}); - builder.Sin(a); + auto a = ConstantR1(&builder, {3.14159f, 0.0f, 1.570796f, -0.78539f}); + Sin(a); ComputeAndCompareR1(&builder, {0.0f, 0.0f, 1.0f, -0.707107f}, {}, error_spec_); @@ -2083,9 +2177,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f}); - auto b = builder.ConstantR1({6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f}); - builder.Atan2(a, b); + auto a = ConstantR1(&builder, {0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f}); + auto b = ConstantR1(&builder, {6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f}); + Atan2(a, b); ComputeAndCompareR1( &builder, @@ -2095,8 +2189,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) { XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f}); - builder.Tanh(a); + auto a = ConstantR1(&builder, {-2.5f, 3.14f, 2.25f}); + Tanh(a); ComputeAndCompareR1(&builder, {-0.986614f, 0.996260f, 0.978026}, {}, error_spec_); @@ -2118,8 +2212,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*input_literal)); - auto input = builder.Parameter(0, input_literal->shape(), "input"); - builder.Tanh(input); + auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + Tanh(input); ComputeAndCompareR1( &builder, @@ -2164,8 +2258,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, client_->TransferToServer(*input_literal)); - auto input = builder.Parameter(0, input_literal->shape(), "input"); - builder.Exp(input); + auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + Exp(input); std::vector expected_result; int64 input_size = input_literal->shape().dimensions(0); @@ -2202,8 +2296,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, client_->TransferToServer(*input_literal)); - auto input = builder.Parameter(0, input_literal->shape(), "input"); - builder.Log(input); + auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + Log(input); std::vector expected_result; int64 input_size = input_literal->shape().dimensions(0); @@ -2218,9 +2312,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { XLA_TEST_F(ArrayElementwiseOpTest, ClzU32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {0, 1, 0x10, 0x10000, 0x700000, 0x12345678, 0xF2345678}); - builder.Clz(a); + auto a = ConstantR1( + &builder, {0, 1, 0x10, 0x10000, 0x700000, 0x12345678, 0xF2345678}); + Clz(a); ComputeAndCompareR1(&builder, {32, 31, 27, 15, 9, 3, 0}, {}); } @@ -2228,8 +2322,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClzU32s) { XLA_TEST_F(ArrayElementwiseOpTest, ClzS64s) { XlaBuilder builder(TestName()); auto a = - builder.ConstantR1({0, 1, 0x80000000, 0x7FFFFFFFF2345678ul, -1}); - builder.Clz(a); + ConstantR1(&builder, {0, 1, 0x80000000, 0x7FFFFFFFF2345678ul, -1}); + Clz(a); ComputeAndCompareR1(&builder, {64, 63, 32, 1, 0}, {}); } @@ -2241,12 +2335,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) { // c---------------------/ XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({1.1f, 2.2f, 3.3f, 4.4f}); - auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); - auto c = builder.ConstantR1({-3.3f, -15.5f, -7.7f, -29.9f}); + auto a = ConstantR1(&builder, {1.1f, 2.2f, 3.3f, 4.4f}); + auto b = ConstantR1(&builder, {2.1f, 3.2f, 4.3f, 5.4f}); + auto c = ConstantR1(&builder, {-3.3f, -15.5f, -7.7f, -29.9f}); - auto add = builder.Add(a, b); - builder.Add(add, c); + auto add = Add(a, b); + Add(add, c); ComputeAndCompareR1(&builder, {-0.1f, -10.1f, -0.1f, -20.1f}, {}, error_spec_); @@ -2259,12 +2353,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) { // a---------------------/ XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); - auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); - auto c = builder.ConstantR1({-3.3f, -15.5f, -7.7f, -29.9f}); + auto a = ConstantR1(&builder, {91.1f, 2.2f, 3.3f, 4.4f}); + auto b = ConstantR1(&builder, {2.1f, 3.2f, 4.3f, 5.4f}); + auto c = ConstantR1(&builder, {-3.3f, -15.5f, -7.7f, -29.9f}); - auto add = builder.Add(b, c); - builder.Add(a, add); + auto add = Add(b, c); + Add(a, add); ComputeAndCompareR1(&builder, {89.9f, -10.1f, -0.1f, -20.1f}, {}, error_spec_); @@ -2276,12 +2370,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddWithNeg) { // b ----- (neg) ----/ XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); - auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); + auto a = ConstantR1(&builder, {91.1f, 2.2f, 3.3f, 4.4f}); + auto b = ConstantR1(&builder, {2.1f, 3.2f, 4.3f, 5.4f}); - auto neg_a = builder.Neg(a); - auto neg_b = builder.Neg(b); - builder.Add(neg_a, neg_b); + auto neg_a = Neg(a); + auto neg_b = Neg(b); + Add(neg_a, neg_b); ComputeAndCompareR1(&builder, {-93.2f, -5.4f, -7.6f, -9.8f}, {}, error_spec_); @@ -2297,14 +2391,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) { // d -----/ XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); - auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); - auto c = builder.ConstantR1({-3.3f, -15.5f, -7.7f, -29.9f}); - auto d = builder.ConstantR1({-19.0f, 10.0f, -40.0f, 20.2f}); + auto a = ConstantR1(&builder, {91.1f, 2.2f, 3.3f, 4.4f}); + auto b = ConstantR1(&builder, {2.1f, 3.2f, 4.3f, 5.4f}); + auto c = ConstantR1(&builder, {-3.3f, -15.5f, -7.7f, -29.9f}); + auto d = ConstantR1(&builder, {-19.0f, 10.0f, -40.0f, 20.2f}); - auto add_ab = builder.Add(a, b); - auto add_cd = builder.Add(c, d); - builder.Add(add_ab, add_cd); + auto add_ab = Add(a, b); + auto add_cd = Add(c, d); + Add(add_ab, add_cd); ComputeAndCompareR1(&builder, {70.9f, -0.1f, -40.1f, 0.1f}, {}, error_spec_); @@ -2312,11 +2406,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) { XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) { XlaBuilder builder(TestName()); - auto a = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - auto b = - builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); - builder.Add(a, b); + auto a = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto b = ConstantR2(&builder, + {{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); + Add(a, b); Array2D expected_array( {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}}); @@ -2326,10 +2420,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) { XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) { // Add a scalar + matrix. XlaBuilder builder(TestName()); - auto a = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - auto scalar = builder.ConstantR0(3.0f); - builder.Add(scalar, a); + auto a = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto scalar = ConstantR0(&builder, 3.0f); + Add(scalar, a); Array2D expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2338,10 +2432,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) { XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) { // Add a matrix + scalar. XlaBuilder builder(TestName()); - auto a = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - auto scalar = builder.ConstantR0(3.0f); - builder.Add(a, scalar); + auto a = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto scalar = ConstantR0(&builder, 3.0f); + Add(a, scalar); Array2D expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2351,13 +2445,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) { // Test simple broadcasting of a R1F32 over R2F32. The vector's size matches // only dim 0 of the matrix. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({20.0f, 40.0f, 60.0f}); + auto v = ConstantR1(&builder, {20.0f, 40.0f, 60.0f}); // clang-format off - auto m = builder.ConstantR2({ + auto m = ConstantR2(&builder, { {-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); // clang-format on - builder.Add(v, m, /*broadcast_dimensions=*/{1}); + Add(v, m, /*broadcast_dimensions=*/{1}); Array2D expected_array( {{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2366,14 +2460,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { // Test broadcasting in Eq comparison. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({42, 73}); - auto m = builder.ConstantR2({{42, 73}, {42, 52}}); + auto v = ConstantR1(&builder, {42, 73}); + auto m = ConstantR2(&builder, {{42, 73}, {42, 52}}); // This test exercises both possible broadcast dimensions for a vector/matrix // comparison. - auto cmp_dim_0 = builder.Eq(v, m, /*broadcast_dimensions=*/{1}); - auto cmp_dim_1 = builder.Eq(v, m, /*broadcast_dimensions=*/{0}); - auto result = builder.Tuple({cmp_dim_0, cmp_dim_1}); + auto cmp_dim_0 = Eq(v, m, /*broadcast_dimensions=*/{1}); + auto cmp_dim_1 = Eq(v, m, /*broadcast_dimensions=*/{0}); + Tuple(&builder, {cmp_dim_0, cmp_dim_1}); auto expected = Literal::MakeTuple( {Literal::CreateR2({{true, true}, {true, false}}).get(), @@ -2384,9 +2478,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { // Test broadcasting in Ne comparison. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({42, 73}); - auto m = builder.ConstantR2({{42, 73}, {42, 52}}); - builder.Ne(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {42, 73}); + auto m = ConstantR2(&builder, {{42, 73}, {42, 52}}); + Ne(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,2] { { 00 }, @@ -2398,9 +2492,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) { // Test broadcasting in Ge comparison. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({1, 2, 3, 4}); - auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - builder.Ge(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {1, 2, 3, 4}); + auto m = ConstantR2(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}}); + Ge(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 1100 }, @@ -2412,9 +2506,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) { // Test broadcasting in Gt comparison. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({1, 2, 3, 4}); - auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - builder.Gt(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {1, 2, 3, 4}); + auto m = ConstantR2(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}}); + Gt(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 0100 }, @@ -2426,9 +2520,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) { // Test broadcasting in Le comparison. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({1, 2, 3, 4}); - auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - builder.Le(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {1, 2, 3, 4}); + auto m = ConstantR2(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}}); + Le(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 1011 }, @@ -2440,9 +2534,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) { // Test broadcasting in Lt comparison. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({1, 2, 3, 4}); - auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - builder.Lt(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {1, 2, 3, 4}); + auto m = ConstantR2(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}}); + Lt(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 0011 }, @@ -2455,9 +2549,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) { // Test simple broadcasting of a R1F32 over R2F32 when the order of binary op // arguments is reversed. XlaBuilder builder(TestName()); - auto m = builder.ConstantR2({{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}}); - auto v = builder.ConstantR1({2.0f, 4.0f, 6.0f}); - builder.Mul(m, v, /*broadcast_dimensions=*/{1}); + auto m = + ConstantR2(&builder, {{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}}); + auto v = ConstantR1(&builder, {2.0f, 4.0f, 6.0f}); + Mul(m, v, /*broadcast_dimensions=*/{1}); Array2D expected_array({{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); } @@ -2468,10 +2563,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) { // m's shape in XLA notation is {3, 2} // md's shape in XLA notation is {3, 1} // The result has shape {3, 2}, where md is broadcast over m - auto m = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - auto md = builder.ConstantR2({{10.0f, 20.0f, 30.0f}}); - builder.Add(m, md); + auto m = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto md = ConstantR2(&builder, {{10.0f, 20.0f, 30.0f}}); + Add(m, md); Array2D expected_array( {{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2483,10 +2578,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) { // m's shape in XLA notation is {3, 2} // md's shape in XLA notation is {1, 2} // The result has shape {3, 2}, where md is broadcast over m - auto m = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - auto md = builder.ConstantR2({{10.0f}, {20.0f}}); - builder.Add(m, md); + auto m = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto md = ConstantR2(&builder, {{10.0f}, {20.0f}}); + Add(m, md); Array2D expected_array( {{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2501,9 +2596,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) { // a's shape in XLA notation is {1, 4} // b's shape in XLA notation is {3, 1} // The result has shape {3, 4}. - auto a = builder.ConstantR2({{0.0f}, {10.0f}, {20.0f}, {30.0f}}); - auto b = builder.ConstantR2({{1.0f, 2.0f, 3.0f}}); - builder.Add(a, b); + auto a = ConstantR2(&builder, {{0.0f}, {10.0f}, {20.0f}, {30.0f}}); + auto b = ConstantR2(&builder, {{1.0f, 2.0f, 3.0f}}); + Add(a, b); Array2D expected_array({{1.0f, 2.0f, 3.0f}, {11.0f, 12.0f, 13.0f}, {21.0f, 22.0f, 23.0f}, @@ -2515,9 +2610,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) { // Add together a (2,2) array and a (2) array, using dimension 0 for // broadcasting (though there are two ways to broadcast these shapes). XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({20.0f, 40.0f}); - auto m = builder.ConstantR2({{10.0f, 50.0f}, {77.0f, 88.0f}}); - builder.Add(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {20.0f, 40.0f}); + auto m = ConstantR2(&builder, {{10.0f, 50.0f}, {77.0f, 88.0f}}); + Add(v, m, /*broadcast_dimensions=*/{1}); Array2D expected_array({{30.0f, 90.0f}, {97.0f, 128.0f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); } @@ -2526,9 +2621,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) { // Add together a (2,2) array and a (2) array, using dimension 1 for // broadcasting (though there are two ways to broadcast these shapes). XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({20.0f, 40.0f}); - auto m = builder.ConstantR2({{10.0f, 50.0f}, {77.0f, 88.0f}}); - builder.Add(v, m, /*broadcast_dimensions=*/{0}); + auto v = ConstantR1(&builder, {20.0f, 40.0f}); + auto m = ConstantR2(&builder, {{10.0f, 50.0f}, {77.0f, 88.0f}}); + Add(v, m, /*broadcast_dimensions=*/{0}); Array2D expected_array({{30.0f, 70.0f}, {117.0f, 128.0f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); } @@ -2538,12 +2633,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) { XlaBuilder builder(TestName()); Array3D a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}); - auto a = builder.ConstantR3FromArray3D(a_3d); + auto a = ConstantR3FromArray3D(&builder, a_3d); Array3D b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}}, {{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}}); - auto b = builder.ConstantR3FromArray3D(b_3d); - builder.Add(a, b); + auto b = ConstantR3FromArray3D(&builder, b_3d); + Add(a, b); Array3D expected_3d( {{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}}, @@ -2565,9 +2660,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) { {11.0f, 12.0f}}, }); // clang-format on - auto a = builder.ConstantR3FromArray3D(a_3d); - auto v = builder.ConstantR1({10.0f, 20.0f}); - builder.Add(a, v, /*broadcast_dimensions=*/{2}); + auto a = ConstantR3FromArray3D(&builder, a_3d); + auto v = ConstantR1(&builder, {10.0f, 20.0f}); + Add(a, v, /*broadcast_dimensions=*/{2}); Array3D expected_3d( {{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}}, @@ -2589,9 +2684,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) { {11.0f, 12.0f}}, }); // clang-format on - auto a = builder.ConstantR3FromArray3D(a_3d); - auto v = builder.ConstantR1({10.0f, 20.0f}); - builder.Add(a, v, /*broadcast_dimensions=*/{0}); + auto a = ConstantR3FromArray3D(&builder, a_3d); + auto v = ConstantR1(&builder, {10.0f, 20.0f}); + Add(a, v, /*broadcast_dimensions=*/{0}); // clang-format off Array3D expected_3d({ @@ -2619,12 +2714,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) { {9.0f, 10.0f}, {11.0f, 12.0f}}, }); - auto a = builder.ConstantR3FromArray3D(a_3d); - auto m = builder.ConstantR2({ + auto a = ConstantR3FromArray3D(&builder, a_3d); + auto m = ConstantR2(&builder, { {10.0f, 20.0f, 30.0f}, {40.0f, 50.0f, 60.0f}, }); - builder.Add(a, m, /*broadcast_dimensions=*/{0, 1}); + Add(a, m, /*broadcast_dimensions=*/{0, 1}); Array3D expected_3d({ {{11.0f, 12.0f}, @@ -2644,12 +2739,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { XlaBuilder builder(TestName()); Array3D a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}); - auto a = builder.ConstantR3FromArray3D(a_3d); + auto a = ConstantR3FromArray3D(&builder, a_3d); Array3D b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}}); - auto b = builder.ConstantR3FromArray3D(b_3d); + auto b = ConstantR3FromArray3D(&builder, b_3d); - builder.Gt(a, b); + Gt(a, b); Array3D expected_3d( {{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}}); @@ -2684,9 +2779,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) { } } - auto a = builder.ConstantR4FromArray4D(*operand_a_4d); - auto b = builder.ConstantR4FromArray4D(*operand_b_4d); - builder.Add(a, b); + auto a = ConstantR4FromArray4D(&builder, *operand_a_4d); + auto b = ConstantR4FromArray4D(&builder, *operand_b_4d); + Add(a, b); ComputeAndCompareR4(&builder, *expected_4d, {}, error_spec_); } @@ -2712,9 +2807,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) { } } - auto a = builder.ConstantR4FromArray4D(*operand_a_4d); - auto b = builder.ConstantR1(operand_b_1d); - builder.Add(a, b, {1}); + auto a = ConstantR4FromArray4D(&builder, *operand_a_4d); + auto b = ConstantR1(&builder, operand_b_1d); + Add(a, b, {1}); ComputeAndCompareR4(&builder, *expected_4d, {}, error_spec_); } @@ -2732,9 +2827,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { XlaBuilder builder(TestName()); std::unique_ptr a_literal = Literal::CreateR4FromArray4DWithLayout( r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); - auto a = builder.ConstantLiteral(*a_literal); - auto b = builder.ConstantR1(r1); - builder.Add(a, b, {1}); + auto a = ConstantLiteral(&builder, *a_literal); + auto b = ConstantR1(&builder, r1); + Add(a, b, {1}); for (int i0 = 0; i0 < d0; ++i0) { for (int i1 = 0; i1 < d1; ++i1) { @@ -2752,8 +2847,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { XlaBuilder builder(TestName()); auto shape = ShapeUtil::MakeOpaqueShape(); - auto x = builder.Parameter(0, shape, "x"); - builder.Add(x, x); + auto x = Parameter(&builder, 0, shape, "x"); + Add(x, x); auto computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), @@ -2763,11 +2858,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) { XlaBuilder builder(TestName()); - auto a = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - auto b = - builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); - builder.Add(a, b, /*broadcast_dimensions=*/{0, 1}); + auto a = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto b = ConstantR2(&builder, + {{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); + Add(a, b, /*broadcast_dimensions=*/{0, 1}); Array2D expected_array( {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}}); @@ -2776,11 +2871,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) { XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) { XlaBuilder builder(TestName()); - auto a = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - auto b = - builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); - builder.Add(a, b, /*broadcast_dimensions=*/{1, 0}); + auto a = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto b = ConstantR2(&builder, + {{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); + Add(a, b, /*broadcast_dimensions=*/{1, 0}); auto computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); @@ -2797,10 +2892,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) { auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); - auto x = builder.Parameter(0, x_literal->shape(), "x"); - auto y = builder.Parameter(1, y_literal->shape(), "y"); - auto slice = builder.Slice(x, {1}, {2}, {1}); - builder.Sub(slice, y); + auto x = Parameter(&builder, 0, x_literal->shape(), "x"); + auto y = Parameter(&builder, 1, y_literal->shape(), "y"); + auto slice = Slice(x, {1}, {2}, {1}); + Sub(slice, y); ComputeAndCompareR1(&builder, {-2, -3}, {x_data.get(), y_data.get()}, error_spec_); diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc index fcd9ff55e393f64476ddd4754e0fa74427f1cb51..8d15b7841bc7298cd6865d8689cc496c0459e4b9 100644 --- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc +++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc @@ -29,10 +29,10 @@ class AxpySimpleTest : public ClientLibraryTestBase {}; TEST_F(AxpySimpleTest, AxTenValues) { XlaBuilder builder("ax_10"); - auto alpha = builder.ConstantR0(3.1415926535); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - builder.Mul(alpha, x); + auto alpha = ConstantR0(&builder, 3.1415926535); + auto x = ConstantR1( + &builder, {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); + Mul(alpha, x); std::vector expected = { -3.14159265, 3.14159265, 6.28318531, -6.28318531, -9.42477796, @@ -42,11 +42,11 @@ TEST_F(AxpySimpleTest, AxTenValues) { XLA_TEST_F(AxpySimpleTest, AxpyZeroValues) { XlaBuilder builder("axpy_10"); - auto alpha = builder.ConstantR0(3.1415926535); - auto x = builder.ConstantR1({}); - auto y = builder.ConstantR1({}); - auto ax = builder.Mul(alpha, x); - builder.Add(ax, y); + auto alpha = ConstantR0(&builder, 3.1415926535); + auto x = ConstantR1(&builder, {}); + auto y = ConstantR1(&builder, {}); + auto ax = Mul(alpha, x); + Add(ax, y); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -54,13 +54,13 @@ XLA_TEST_F(AxpySimpleTest, AxpyZeroValues) { TEST_F(AxpySimpleTest, AxpyTenValues) { XlaBuilder builder("axpy_10"); - auto alpha = builder.ConstantR0(3.1415926535); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto y = builder.ConstantR1( - {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0}); - auto ax = builder.Mul(alpha, x); - builder.Add(ax, y); + auto alpha = ConstantR0(&builder, 3.1415926535); + auto x = ConstantR1( + &builder, {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); + auto y = ConstantR1( + &builder, {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0}); + auto ax = Mul(alpha, x); + Add(ax, y); TF_ASSERT_OK_AND_ASSIGN(ProgramShape shape, builder.GetProgramShape()); diff --git a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc index 22c3394e6f34bd018ffaaaa4d9d68339673c3764..8c227df7f04e79ccc332062d0889d282c0f5e40f 100644 --- a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc +++ b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc @@ -35,10 +35,10 @@ class BadRngShapeValidationTest : public ClientLibraryTestBase {}; TEST_F(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) { XlaBuilder builder(TestName()); - auto zero = builder.ConstantR0(0.0); - auto one = builder.ConstantR0(1.0); + auto zero = ConstantR0(&builder, 0.0); + auto one = ConstantR0(&builder, 1.0); Shape default_constructed; - builder.RngUniform(zero, one, default_constructed); + RngUniform(zero, one, default_constructed); StatusOr computation = builder.Build(); EXPECT_FALSE(computation.ok()); @@ -49,13 +49,13 @@ TEST_F(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) { TEST_F(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) { XlaBuilder builder(TestName()); - auto zero = builder.ConstantR0(0.0); - auto one = builder.ConstantR0(1.0); + auto zero = ConstantR0(&builder, 0.0); + auto one = ConstantR0(&builder, 1.0); Shape sans_layout; sans_layout.set_element_type(F32); sans_layout.add_dimensions(1); - builder.RngUniform(zero, one, sans_layout); + RngUniform(zero, one, sans_layout); StatusOr computation = builder.Build(); ASSERT_TRUE(computation.ok()); diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index f3dac75a44b948c4b45b80b93e7462073010979e..d9d7ba1362a6975465971f4bc29da4d541e2f821 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -101,9 +101,9 @@ INSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest, XLA_TEST_P(BatchNormalizationTest, SubtractInZ) { XlaBuilder builder("subtract_in_z_one_sample"); - auto x = builder.ConstantLiteral(input_literal_); - auto y = builder.ConstantR1({3.14, 4.25}); - builder.Sub(x, y, /*broadcast_dimensions=*/{1}); + auto x = ConstantLiteral(&builder, input_literal_); + auto y = ConstantR1(&builder, {3.14, 4.25}); + Sub(x, y, /*broadcast_dimensions=*/{1}); Array4D expected(kSamples, kZ, kY, kX); Array2D pz({ @@ -117,8 +117,8 @@ XLA_TEST_P(BatchNormalizationTest, SubtractInZ) { XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) { XlaBuilder builder("square_tesseract_elementwise"); - auto x = builder.ConstantLiteral(input_literal_); - builder.SquareF32(x); + auto x = ConstantLiteral(&builder, input_literal_); + SquareF32(x); using tensorflow::MathUtil; @@ -134,11 +134,10 @@ XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) { XLA_TEST_P(BatchNormalizationTest, SumToZ) { XlaBuilder builder("sum_to_z"); - auto input_activations = builder.ConstantLiteral(input_literal_); + auto input_activations = ConstantLiteral(&builder, input_literal_); XlaComputation add = CreateScalarAddComputation(F32, &builder); // Reduce all but the Z dimension. - builder.Reduce(input_activations, builder.ConstantR0(0.0f), add, - {0, 2, 3}); + Reduce(input_activations, ConstantR0(&builder, 0.0f), add, {0, 2, 3}); std::vector expected = {6, 12.6}; ComputeAndCompareR1(&builder, expected, {}, error_spec_); @@ -146,13 +145,13 @@ XLA_TEST_P(BatchNormalizationTest, SumToZ) { XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) { XlaBuilder builder("square_and_reduce"); - auto input_activations = builder.ConstantLiteral(input_literal_); - auto set_means = builder.ConstantR1({2.f, 4.2f}); - auto activation_deviations = builder.Sub(input_activations, set_means, - /*broadcast_dimensions=*/{1}); + auto input_activations = ConstantLiteral(&builder, input_literal_); + auto set_means = ConstantR1(&builder, {2.f, 4.2f}); + auto activation_deviations = Sub(input_activations, set_means, + /*broadcast_dimensions=*/{1}); XlaComputation add = CreateScalarAddComputation(F32, &builder); - auto dev_squares = builder.SquareF32(activation_deviations); - builder.Reduce(dev_squares, builder.ConstantR0(0.0f), add, {0, 2, 3}); + auto dev_squares = SquareF32(activation_deviations); + Reduce(dev_squares, ConstantR0(&builder, 0.0f), add, {0, 2, 3}); std::vector expected = {18, 0.06}; ComputeAndCompareR1(&builder, expected, {}, error_spec_); @@ -160,8 +159,8 @@ XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) { XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) { XlaBuilder builder("variance_to_stddev"); - auto variance = builder.ConstantR1({6.f, .02f}); - builder.SqrtF32(variance); + auto variance = ConstantR1(&builder, {6.f, .02f}); + SqrtF32(variance); std::vector expected = {2.44948974f, 0.14142136f}; ComputeAndCompareR1(&builder, expected, {}, error_spec_); @@ -172,50 +171,50 @@ XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) { XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) { XlaBuilder builder("batch_normalize_per_spec"); auto input_activations = - CheckShape(&builder, builder.ConstantLiteral(input_literal_), + CheckShape(&builder, ConstantLiteral(&builder, input_literal_), ShapeUtil::MakeShape(F32, {3, 2, 1, 1})); - auto gamma = builder.ConstantR1({1.0, 1.0}); - auto beta = builder.ConstantR1({0.0, 0.0}); + auto gamma = ConstantR1(&builder, {1.0, 1.0}); + auto beta = ConstantR1(&builder, {0.0, 0.0}); XlaComputation add = CreateScalarAddComputation(F32, &builder); // Reduce all dimensions except dimension 1. Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2}); auto sum = CheckShape( &builder, - builder.Reduce(input_activations, builder.ConstantR0(0.0f), add, - /*dimensions_to_reduce=*/{0, 2, 3}), + Reduce(input_activations, ConstantR0(&builder, 0.0f), add, + /*dimensions_to_reduce=*/{0, 2, 3}), TwoElementVectorF32); auto input_shape = builder.GetShape(input_activations).ConsumeValueOrDie(); auto sum_shape = builder.GetShape(sum).ConsumeValueOrDie(); - auto count = builder.ConstantR0(ShapeUtil::ElementsIn(input_shape) / - ShapeUtil::ElementsIn(sum_shape)); - auto set_means = builder.Div(sum, count); + auto count = + ConstantR0(&builder, ShapeUtil::ElementsIn(input_shape) / + ShapeUtil::ElementsIn(sum_shape)); + auto set_means = Div(sum, count); const float kEpsilon = 1e-9f; - auto epsilon = builder.ConstantR0(kEpsilon); - auto epsilon2 = builder.ConstantR1({kEpsilon, kEpsilon}); - auto activation_deviations = builder.Sub(input_activations, set_means, - /*broadcast_dimensions=*/{1}); - auto dev_squares = builder.SquareF32(activation_deviations); - auto sum_of_squares = CheckShape( - &builder, - builder.Reduce(dev_squares, builder.ConstantR0(0.0f), add, - /*dimensions_to_reduce=*/{0, 2, 3}), - TwoElementVectorF32); - auto variance = builder.Div(sum_of_squares, count); - auto standard_deviation = builder.SqrtF32(variance); + auto epsilon = ConstantR0(&builder, kEpsilon); + auto epsilon2 = ConstantR1(&builder, {kEpsilon, kEpsilon}); + auto activation_deviations = Sub(input_activations, set_means, + /*broadcast_dimensions=*/{1}); + auto dev_squares = SquareF32(activation_deviations); + auto sum_of_squares = + CheckShape(&builder, + Reduce(dev_squares, ConstantR0(&builder, 0.0f), add, + /*dimensions_to_reduce=*/{0, 2, 3}), + TwoElementVectorF32); + auto variance = Div(sum_of_squares, count); + auto standard_deviation = SqrtF32(variance); auto standard_deviation_above_epsilon = - CheckShape(&builder, builder.Gt(standard_deviation, epsilon), + CheckShape(&builder, Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2})); - auto gt_eps = builder.Select(standard_deviation_above_epsilon, - standard_deviation, epsilon2); - auto normalization_factors = builder.ReciprocalF32(gt_eps); + auto gt_eps = + Select(standard_deviation_above_epsilon, standard_deviation, epsilon2); + auto normalization_factors = ReciprocalF32(gt_eps); auto normalized_input_activations = - builder.Mul(activation_deviations, normalization_factors, - /*broadcast_dimensions=*/{1}); - /* auto output_activations = */ builder.Add( - builder.Mul(normalized_input_activations, gamma, - /*broadcast_dimensions=*/{1}), - beta, /*broadcast_dimensions=*/{1}); + Mul(activation_deviations, normalization_factors, + /*broadcast_dimensions=*/{1}); + /* auto output_activations = */ Add(Mul(normalized_input_activations, gamma, + /*broadcast_dimensions=*/{1}), + beta, /*broadcast_dimensions=*/{1}); Array4D expected(kSamples, kZ, kY, kX); Array2D pz({ @@ -232,15 +231,15 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) { const int kFeatureIndex = 3; XlaBuilder builder(TestName()); - auto operand = builder.ConstantR4FromArray4D( - {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}}); + auto operand = ConstantR4FromArray4D( + &builder, {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}}); - auto scale = builder.ConstantR1({2.0f, 3.0f}); + auto scale = ConstantR1(&builder, {2.0f, 3.0f}); - auto offset = builder.ConstantR1({1.0f, 2.0f}); + auto offset = ConstantR1(&builder, {1.0f, 2.0f}); - builder.BatchNormTraining(operand, scale, offset, - /*epsilon=*/0.001, kFeatureIndex); + BatchNormTraining(operand, scale, offset, + /*epsilon=*/0.001, kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR4({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, @@ -252,19 +251,20 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) { ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); } -XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnSublane) { +XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) { const int kFeatureIndex = 2; XlaBuilder builder(TestName()); - auto operand = builder.ConstantR4FromArray4D( + auto operand = ConstantR4FromArray4D( + &builder, {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); - auto scale = builder.ConstantR1({2.0f, 3.0f}); + auto scale = ConstantR1(&builder, {2.0f, 3.0f}); - auto offset = builder.ConstantR1({1.0f, 2.0f}); + auto offset = ConstantR1(&builder, {1.0f, 2.0f}); - builder.BatchNormTraining(operand, scale, offset, - /*epsilon=*/0.001, kFeatureIndex); + BatchNormTraining(operand, scale, offset, + /*epsilon=*/0.001, kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR4({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, @@ -294,8 +294,8 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) { CreateR1Parameter(std::vector(260, 1.0f), /*parameter_number=*/2, "offset", &builder, &h2); - builder.BatchNormTraining(h0, h1, h2, - /*epsilon=*/1, kFeatureIndex); + BatchNormTraining(h0, h1, h2, + /*epsilon=*/1, kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)) @@ -327,8 +327,8 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) { /*parameter_number=*/2, "offset", &builder, &h2); // var = 125, mean = 15, epsilon = -100 - builder.BatchNormTraining(h0, h1, h2, - /*epsilon=*/-100, kFeatureIndex); + BatchNormTraining(h0, h1, h2, + /*epsilon=*/-100, kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR3FromArray3D({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) @@ -346,19 +346,20 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) { XlaBuilder builder(TestName()); auto operand = - builder.ConstantR4FromArray4D(Array4D(2, 2, 2, 1, 0.0f)); + ConstantR4FromArray4D(&builder, Array4D(2, 2, 2, 1, 0.0f)); - auto scale = builder.ConstantR1({1.0f, 1.0f}); + auto scale = ConstantR1(&builder, {1.0f, 1.0f}); - auto mean = builder.ConstantR1({0.0f, 0.0f}); + auto mean = ConstantR1(&builder, {0.0f, 0.0f}); - auto var = builder.ConstantR1({1.0f, 1.0f}); + auto var = ConstantR1(&builder, {1.0f, 1.0f}); - auto grad_output = builder.ConstantR4FromArray4D( + auto grad_output = ConstantR4FromArray4D( + &builder, {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); - builder.BatchNormGrad(operand, scale, mean, var, grad_output, - /*epsilon=*/0.0, kFeatureIndex); + BatchNormGrad(operand, scale, mean, var, grad_output, + /*epsilon=*/0.0, kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR4({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}}, @@ -518,11 +519,11 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { auto input_literal = Literal::CreateR4FromArray4D(input_array); auto input_activations = - builder.Parameter(0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal->shape(), "input"); auto scale_activations = - builder.Parameter(1, scale_literal->shape(), "offset"); + Parameter(&builder, 1, scale_literal->shape(), "offset"); auto offset_activations = - builder.Parameter(2, offset_literal->shape(), "scale"); + Parameter(&builder, 2, offset_literal->shape(), "scale"); auto expected = Literal::MakeTuple({expected_normalized.get(), Literal::CreateR1(mean).get(), @@ -535,8 +536,8 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { std::unique_ptr offset_data = client_->TransferToServer(*offset_literal).ConsumeValueOrDie(); - builder.BatchNormTraining(input_activations, scale_activations, - offset_activations, epsilon, feature_index); + BatchNormTraining(input_activations, scale_activations, offset_activations, + epsilon, feature_index); // Run all HLO passes during this test. In particular, ClientLibraryTestBase // disables constant folding, but we want it enabled for our zero-sized tensor @@ -618,14 +619,14 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) { auto input_literal = Literal::CreateR4FromArray4D(input_array); auto input_activations = - builder.Parameter(0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal->shape(), "input"); auto scale_activations = - builder.Parameter(1, scale_literal->shape(), "offset"); + Parameter(&builder, 1, scale_literal->shape(), "offset"); auto offset_activations = - builder.Parameter(2, offset_literal->shape(), "scale"); - auto mean_activations = builder.Parameter(3, mean_literal->shape(), "mean"); + Parameter(&builder, 2, offset_literal->shape(), "scale"); + auto mean_activations = Parameter(&builder, 3, mean_literal->shape(), "mean"); auto variance_activations = - builder.Parameter(4, var_literal->shape(), "variance"); + Parameter(&builder, 4, var_literal->shape(), "variance"); Array4D expected = normalized; @@ -640,9 +641,9 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) { std::unique_ptr variance_data = client_->TransferToServer(*var_literal).ConsumeValueOrDie(); - builder.BatchNormInference(input_activations, scale_activations, - offset_activations, mean_activations, - variance_activations, epsilon, feature_index); + BatchNormInference(input_activations, scale_activations, offset_activations, + mean_activations, variance_activations, epsilon, + feature_index); // Run all HLO passes during this test. In particular, ClientLibraryTestBase // disables constant folding, but we want it enabled for our zero-sized tensor @@ -807,12 +808,14 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { auto grad_output_literal = Literal::CreateR4FromArray4D(grad_output_array); - auto input_parameter = builder.Parameter(0, input_literal->shape(), "input"); - auto scale_parameter = builder.Parameter(1, scale_literal->shape(), "scale"); - auto mean_parameter = builder.Parameter(2, mean_literal->shape(), "mean"); - auto var_parameter = builder.Parameter(3, var_literal->shape(), "variance"); + auto input_parameter = + Parameter(&builder, 0, input_literal->shape(), "input"); + auto scale_parameter = + Parameter(&builder, 1, scale_literal->shape(), "scale"); + auto mean_parameter = Parameter(&builder, 2, mean_literal->shape(), "mean"); + auto var_parameter = Parameter(&builder, 3, var_literal->shape(), "variance"); auto grad_output_parameter = - builder.Parameter(4, grad_output_literal->shape(), "grad_output"); + Parameter(&builder, 4, grad_output_literal->shape(), "grad_output"); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -825,9 +828,8 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { std::unique_ptr grad_output_data = client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie(); - builder.BatchNormGrad(input_parameter, scale_parameter, mean_parameter, - var_parameter, grad_output_parameter, epsilon, - feature_index); + BatchNormGrad(input_parameter, scale_parameter, mean_parameter, var_parameter, + grad_output_parameter, epsilon, feature_index); auto expected = Literal::MakeTuple({expected_grad_activation.get(), diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index ca337e78840e77377719636cd4cf33af2578210d..f40d03bea79de2a78814a0ad9f6cae6098d1449b 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -51,9 +51,9 @@ class Bfloat16Test : public ClientLibraryTestBase { XLA_TEST_F(Bfloat16Test, ScalarOperation) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR0(static_cast(2.0f)); - auto y = builder.ConstantR0(static_cast(1.0f)); - builder.Add(x, y); + auto x = ConstantR0(&builder, static_cast(2.0f)); + auto y = ConstantR0(&builder, static_cast(1.0f)); + Add(x, y); ComputeAndCompareR0(&builder, static_cast(3.0f), {}, error_spec_); @@ -61,8 +61,8 @@ XLA_TEST_F(Bfloat16Test, ScalarOperation) { XLA_TEST_F(Bfloat16Test, LogOperation) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR0(static_cast(4.0f)); - builder.Log(x); + auto x = ConstantR0(&builder, static_cast(4.0f)); + Log(x); ComputeAndCompareR0(&builder, static_cast(1.387f), {}, error_spec_); @@ -70,7 +70,7 @@ XLA_TEST_F(Bfloat16Test, LogOperation) { XLA_TEST_F(Bfloat16Test, NegateScalarF16) { XlaBuilder builder(TestName()); - builder.Neg(builder.ConstantR0(static_cast(2.1f))); + Neg(ConstantR0(&builder, static_cast(2.1f))); ComputeAndCompareR0(&builder, static_cast(-2.1f), {}, error_spec_); @@ -80,20 +80,20 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { const int kFeatureIndex = 2; XlaBuilder builder(TestName()); - auto operand = builder.ConstantR4FromArray4D( + auto operand = ConstantR4FromArray4D( + &builder, {{{{static_cast(1.f)}, {static_cast(2.f)}}, {{static_cast(3.f)}, {static_cast(4.f)}}}, {{{static_cast(5.f)}, {static_cast(6.f)}}, {{static_cast(7.f)}, {static_cast(8.f)}}}}); - auto scale = builder.ConstantR1( - {static_cast(2.0f), static_cast(3.0f)}); + auto scale = ConstantR1( + &builder, {static_cast(2.0f), static_cast(3.0f)}); - auto offset = builder.ConstantR1( - {static_cast(1.0f), static_cast(2.0f)}); + auto offset = ConstantR1( + &builder, {static_cast(1.0f), static_cast(2.0f)}); - auto tuple = builder.BatchNormTraining(operand, scale, offset, - /*epsilon=*/0.001, kFeatureIndex); + BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR4( @@ -117,26 +117,27 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) { const int kFeatureIndex = 2; XlaBuilder builder(TestName()); - auto operand = builder.ConstantR4FromArray4D( - Array4D(2, 2, 2, 1, static_cast(0.0f))); + auto operand = ConstantR4FromArray4D( + &builder, Array4D(2, 2, 2, 1, static_cast(0.0f))); - auto scale = builder.ConstantR1( - {static_cast(1.0f), static_cast(1.0f)}); + auto scale = ConstantR1( + &builder, {static_cast(1.0f), static_cast(1.0f)}); - auto mean = builder.ConstantR1( - {static_cast(0.0f), static_cast(0.0f)}); + auto mean = ConstantR1( + &builder, {static_cast(0.0f), static_cast(0.0f)}); - auto var = builder.ConstantR1( - {static_cast(1.0f), static_cast(1.0f)}); + auto var = ConstantR1( + &builder, {static_cast(1.0f), static_cast(1.0f)}); - auto grad_output = builder.ConstantR4FromArray4D( + auto grad_output = ConstantR4FromArray4D( + &builder, {{{{static_cast(1.f)}, {static_cast(2.f)}}, {{static_cast(3.f)}, {static_cast(4.f)}}}, {{{static_cast(5.f)}, {static_cast(6.f)}}, {{static_cast(7.f)}, {static_cast(8.f)}}}}); - builder.BatchNormGrad(operand, scale, mean, var, grad_output, - /*epsilon=*/0.0, kFeatureIndex); + BatchNormGrad(operand, scale, mean, var, grad_output, + /*epsilon=*/0.0, kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR4( diff --git a/tensorflow/compiler/xla/tests/binop_scaling_test.cc b/tensorflow/compiler/xla/tests/binop_scaling_test.cc index 48203b1d40ea69ff00a57c2c9e42620739b23d59..20cb989751ad69e2f3cf97c87c43293951f599ab 100644 --- a/tensorflow/compiler/xla/tests/binop_scaling_test.cc +++ b/tensorflow/compiler/xla/tests/binop_scaling_test.cc @@ -33,9 +33,9 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixRowVector_32x4) { auto arhs = MakeLinspaceArray2D(0.0, 1.0, 1, 4); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR2FromArray2D(*alhs); - auto rhs = builder.ConstantR2FromArray2D(*arhs); - builder.Add(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, *alhs); + auto rhs = ConstantR2FromArray2D(&builder, *arhs); + Add(lhs, rhs); auto aexpected = ReferenceUtil::MapWithIndexArray2D( *alhs, [&](float lhs_value, int64 row, int64 col) { @@ -49,9 +49,9 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixRowVector_129x129) { auto arhs = MakeLinspaceArray2D(0.0, 1.0, 1, 129); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR2FromArray2D(*alhs); - auto rhs = builder.ConstantR2FromArray2D(*arhs); - builder.Add(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, *alhs); + auto rhs = ConstantR2FromArray2D(&builder, *arhs); + Add(lhs, rhs); auto aexpected = ReferenceUtil::MapWithIndexArray2D( *alhs, [&](float lhs_value, int64 row, int64 col) { @@ -65,9 +65,9 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_9x5) { auto arhs = MakeLinspaceArray2D(0.0, 1.0, 9, 1); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR2FromArray2D(*alhs); - auto rhs = builder.ConstantR2FromArray2D(*arhs); - builder.Add(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, *alhs); + auto rhs = ConstantR2FromArray2D(&builder, *arhs); + Add(lhs, rhs); auto aexpected = ReferenceUtil::MapWithIndexArray2D( *alhs, [&](float lhs_value, int64 row, int64 col) { @@ -81,9 +81,9 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_129x257) { auto arhs = MakeLinspaceArray2D(0.0, 1.0, 129, 1); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR2FromArray2D(*alhs); - auto rhs = builder.ConstantR2FromArray2D(*arhs); - builder.Add(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, *alhs); + auto rhs = ConstantR2FromArray2D(&builder, *arhs); + Add(lhs, rhs); auto aexpected = ReferenceUtil::MapWithIndexArray2D( *alhs, [&](float lhs_value, int64 row, int64 col) { @@ -94,11 +94,12 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_129x257) { TEST_F(BinopScalingTest, R0PlusR2F32) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR0(42.0); - auto rhs = builder.ConstantR2({ - {1.0, 2.0}, {3.0, 4.0}, - }); - builder.Add(lhs, rhs); + auto lhs = ConstantR0(&builder, 42.0); + auto rhs = ConstantR2(&builder, { + {1.0, 2.0}, + {3.0, 4.0}, + }); + Add(lhs, rhs); Array2D expected(2, 2); expected(0, 0) = 42.0 + 1.0; @@ -129,9 +130,9 @@ TEST_F(BinopScalingTest, R4PlusR0S32) { }); // clang-format on - auto lhs = builder.ConstantR4FromArray4D(lhs_array); - auto rhs = builder.ConstantR0(42); - builder.Add(lhs, rhs); + auto lhs = ConstantR4FromArray4D(&builder, lhs_array); + auto rhs = ConstantR0(&builder, 42); + Add(lhs, rhs); ComputeAndCompareR4(&builder, expected, {}); } diff --git a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc index bff60f25ec8f15d372d251ac313200301a04f20f..d531e8fa82e47f7bcd278f10da2c205e44db0ac1 100644 --- a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc +++ b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc @@ -43,8 +43,8 @@ class BitcastConvertTest : public ClientLibraryTestBase { TEST_F(BitcastConvertTest, ConvertR1S32ToR1S32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42, 64}); - builder.BitcastConvertType(a, S32); + auto a = ConstantR1(&builder, {42, 64}); + BitcastConvertType(a, S32); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}); @@ -52,8 +52,8 @@ TEST_F(BitcastConvertTest, ConvertR1S32ToR1S32) { TEST_F(BitcastConvertTest, ConvertR1F32ToR1F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0f, 64.0f}); - builder.BitcastConvertType(a, F32); + auto a = ConstantR1(&builder, {42.0f, 64.0f}); + BitcastConvertType(a, F32); std::vector expected = {42.0f, 64.0f}; ComputeAndCompareR1(&builder, expected, {}); @@ -62,10 +62,10 @@ TEST_F(BitcastConvertTest, ConvertR1F32ToR1F32) { TEST_F(BitcastConvertTest, BitcastR1S32ToR1F32) { XlaBuilder builder(TestName()); auto a = - builder.ConstantR1({0, static_cast(0x80000000), 0x3F800000, - static_cast(0xBF800000), 0x3F000000, - static_cast(0xBF000000)}); - builder.BitcastConvertType(a, F32); + ConstantR1(&builder, {0, static_cast(0x80000000), + 0x3F800000, static_cast(0xBF800000), + 0x3F000000, static_cast(0xBF000000)}); + BitcastConvertType(a, F32); std::vector expected = {0.0f, -0.0f, 1.0f, -1.0f, 0.5f, -0.5f}; ComputeAndCompareR1(&builder, expected, {}); @@ -73,8 +73,8 @@ TEST_F(BitcastConvertTest, BitcastR1S32ToR1F32) { XLA_TEST_F(BitcastConvertTest, ConvertR1S0S32ToR1S0F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.BitcastConvertType(a, F32); + auto a = ConstantR1(&builder, {}); + BitcastConvertType(a, F32); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}); @@ -82,8 +82,8 @@ XLA_TEST_F(BitcastConvertTest, ConvertR1S0S32ToR1S0F32) { TEST_F(BitcastConvertTest, ConvertR1F32ToR1S32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.6, 64.4}); - builder.BitcastConvertType(a, S32); + auto a = ConstantR1(&builder, {42.6, 64.4}); + BitcastConvertType(a, S32); std::vector expected = {0x422a6666, 0x4280cccd}; ComputeAndCompareR1(&builder, expected, {}); @@ -91,9 +91,9 @@ TEST_F(BitcastConvertTest, ConvertR1F32ToR1S32) { TEST_F(BitcastConvertTest, ConvertS32Extremes) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {std::numeric_limits::min(), std::numeric_limits::max()}); - builder.BitcastConvertType(a, F32); + auto a = ConstantR1(&builder, {std::numeric_limits::min(), + std::numeric_limits::max()}); + BitcastConvertType(a, F32); std::vector expected = {-0.0f, NAN}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0, 0)); @@ -102,10 +102,10 @@ TEST_F(BitcastConvertTest, ConvertS32Extremes) { TEST_F(BitcastConvertTest, ConvertMapToS32) { XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); - auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in"); - b->BitcastConvertType(param, S32); - auto a = builder.ConstantR1({42.0f, 64.0f}); - builder.Map({a}, b->BuildAndNoteError(), {0}); + auto param = Parameter(b.get(), 0, ShapeUtil::MakeShape(F32, {}), "in"); + BitcastConvertType(param, S32); + auto a = ConstantR1(&builder, {42.0f, 64.0f}); + Map(&builder, {a}, b->BuildAndNoteError(), {0}); std::vector expected = {0x42280000, 0x42800000}; ComputeAndCompareR1(&builder, expected, {}); @@ -114,10 +114,10 @@ TEST_F(BitcastConvertTest, ConvertMapToS32) { TEST_F(BitcastConvertTest, ConvertMapToF32) { XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); - auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in"); - b->BitcastConvertType(param, F32); - auto a = builder.ConstantR1({0x42280000, 0x42800000}); - builder.Map({a}, b->BuildAndNoteError(), {0}); + auto param = Parameter(b.get(), 0, ShapeUtil::MakeShape(S32, {}), "in"); + BitcastConvertType(param, F32); + auto a = ConstantR1(&builder, {0x42280000, 0x42800000}); + Map(&builder, {a}, b->BuildAndNoteError(), {0}); std::vector expected = {42.0f, 64.0f}; ComputeAndCompareR1(&builder, expected, {}); @@ -130,9 +130,9 @@ TEST_F(BitcastConvertTest, ConvertMapToF32) { // the new convert should have the same element type as the old convert. TEST_F(BitcastConvertTest, ConvertReshape) { XlaBuilder builder(TestName()); - auto input = builder.ConstantR1({0x42280000}); - auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); - builder.BitcastConvertType(reshape, F32); + auto input = ConstantR1(&builder, {0x42280000}); + auto reshape = Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); + BitcastConvertType(reshape, F32); ComputeAndCompareR0(&builder, 42.0f, {}); } diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 3a0f51fc66d65c8684bd607b9e8103559cd4d8d4..5fdd1018a41413aa6f4e08a0c02a40bf17f1f882 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -37,17 +37,17 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { XlaBuilder* builder) { switch (op) { case HloOpcode::kMinimum: { - return builder->Min(lhs, rhs); + return Min(lhs, rhs); } case HloOpcode::kMaximum: { - return builder->Max(lhs, rhs); + return Max(lhs, rhs); } case HloOpcode::kMultiply: { - return builder->Mul(lhs, rhs); + return Mul(lhs, rhs); } default: { // Default to Add - return builder->Add(lhs, rhs); + return Add(lhs, rhs); } } } @@ -104,13 +104,13 @@ using ::testing::HasSubstr; XLA_TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) { XlaBuilder b(TestName()); - b.Broadcast(b.ConstantR0(1.5), {}); + Broadcast(ConstantR0(&b, 1.5), {}); ComputeAndCompareR0(&b, 1.5, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) { XlaBuilder b(TestName()); - b.Broadcast(b.ConstantR0(2.25), {2, 3}); + Broadcast(ConstantR0(&b, 2.25), {2, 3}); Array2D expected(2, 3, 2.25); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } @@ -122,7 +122,7 @@ XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) { CreateR0Parameter(2.25f, /*parameter_number=*/0, /*name=*/"src", /*builder=*/&b, /*data_handle=*/&src); - b.Broadcast(src, {2, 3}); + Broadcast(src, {2, 3}); Array2D expected(2, 3, 2.25); ComputeAndCompareR2(&b, expected, {param_data.get()}, ErrorSpec(0.0001)); @@ -130,21 +130,21 @@ XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) { XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) { XlaBuilder b(TestName()); - b.Broadcast(b.ConstantR0(2.25), {2, 0}); + Broadcast(ConstantR0(&b, 2.25), {2, 0}); Array2D expected(2, 0); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_0x2) { XlaBuilder b(TestName()); - b.Broadcast(b.ConstantR0(2.25), {0, 2}); + Broadcast(ConstantR0(&b, 2.25), {0, 2}); Array2D expected(0, 2); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { XlaBuilder b(TestName()); - b.Broadcast(b.ConstantR1({1, 2, 3}), {2}); + Broadcast(ConstantR1(&b, {1, 2, 3}), {2}); Array2D expected(2, 3); expected(0, 0) = 1; @@ -172,7 +172,7 @@ XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { XlaOp x, y; auto x_data = CreateR2Parameter(x_vals, 0, "x", &b, &x); auto y_data = CreateR3Parameter(y_vals, 1, "y", &b, &y); - b.And(x, y, /*broadcast_dimensions=*/{1, 2}); + And(x, y, /*broadcast_dimensions=*/{1, 2}); Array3D expected(2, 2, 1); expected(0, 0, 0) = false; @@ -185,7 +185,7 @@ XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) { XlaBuilder b(TestName()); - b.Broadcast(b.ConstantR1({}), {2}); + Broadcast(ConstantR1(&b, {}), {2}); Array2D expected(2, 0); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); @@ -193,7 +193,7 @@ XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) { XLA_TEST_F(BroadcastSimpleTest, 1DToZeroElement2D) { XlaBuilder b(TestName()); - b.Broadcast(b.ConstantR1({1, 2, 3}), {0}); + Broadcast(ConstantR1(&b, {1, 2, 3}), {0}); Array2D expected(0, 3); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); @@ -209,10 +209,10 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { // dimensions. XlaBuilder b(TestName()); - b.Add(b.ConstantR2({{1.0, 5.0}}), - b.ConstantLiteral(*Literal::CreateR3( - {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), - /*broadcast_dimensions=*/{1, 2}); + Add(ConstantR2(&b, {{1.0, 5.0}}), + ConstantLiteral(&b, *Literal::CreateR3( + {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), + /*broadcast_dimensions=*/{1, 2}); auto expected = Literal::CreateR3({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, @@ -260,9 +260,10 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { MakeR3Data(spec.input_bounds, spec.minor2major_layout, &r3_implicit_shape, &r3_implicit_array, 1.0, 0.2, 56789); - auto r3_implicit_parameter = builder.Parameter(0, r3_implicit_shape, "input"); - auto r3_parameter = builder.Parameter(1, r3_shape, "input"); - XlaOp op = BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder); + auto r3_implicit_parameter = + Parameter(&builder, 0, r3_implicit_shape, "input"); + auto r3_parameter = Parameter(&builder, 1, r3_shape, "input"); + BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder); Array3D expected_array(spec.output_bounds[0], spec.output_bounds[1], spec.output_bounds[2]); @@ -306,7 +307,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { auto r1 = CreateR3Parameter(r1d, 1, "r1", &b, &r1h); auto r3 = CreateR3Parameter(r3d, 0, "r3", &b, &r3h); - b.Add(r3h, r1h); + Add(r3h, r1h); auto expected = Literal::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); @@ -317,10 +318,10 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { XlaBuilder b(TestName()); - auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}}})); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r3, r1); + auto r1 = ConstantLiteral(&b, *Literal::CreateR3({{{1, 2}}})); + auto r3 = ConstantLiteral( + &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r3, r1); auto expected = Literal::CreateR3({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); @@ -330,10 +331,10 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { XlaBuilder b(TestName()); - auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1}, {2}}})); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r3, r1); + auto r1 = ConstantLiteral(&b, *Literal::CreateR3({{{1}, {2}}})); + auto r3 = ConstantLiteral( + &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r3, r1); auto expected = Literal::CreateR3({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); @@ -343,10 +344,10 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { XlaBuilder b(TestName()); - auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}, {3, 4}}})); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r3, r1); + auto r1 = ConstantLiteral(&b, *Literal::CreateR3({{{1, 2}, {3, 4}}})); + auto r3 = ConstantLiteral( + &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r3, r1); auto expected = Literal::CreateR3({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); @@ -356,10 +357,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { XlaBuilder b(TestName()); - auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}}, {{3, 4}}})); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r3, r1); + auto r1 = + ConstantLiteral(&b, *Literal::CreateR3({{{1, 2}}, {{3, 4}}})); + auto r3 = ConstantLiteral( + &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r3, r1); auto expected = Literal::CreateR3({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); @@ -370,10 +372,10 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { XlaBuilder b(TestName()); auto r1 = - b.ConstantLiteral(*Literal::CreateR3({{{1}, {2}}, {{3}, {4}}})); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r3, r1); + ConstantLiteral(&b, *Literal::CreateR3({{{1}, {2}}, {{3}, {4}}})); + auto r3 = ConstantLiteral( + &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r3, r1); auto expected = Literal::CreateR3({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); @@ -383,10 +385,10 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { XlaBuilder b(TestName()); - auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1}}})); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r3, r1); + auto r1 = ConstantLiteral(&b, *Literal::CreateR3({{{1}}})); + auto r3 = ConstantLiteral( + &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r3, r1); auto expected = Literal::CreateR3({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); @@ -509,14 +511,14 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { &r2_implicit_shape2, &r2_implicit_array2, 0.8, 0.4, 56789); auto r2_implicit_parameter1 = - builder.Parameter(0, r2_implicit_shape1, "input0"); - auto r2_parameter = builder.Parameter(1, r2_shape, "input1"); + Parameter(&builder, 0, r2_implicit_shape1, "input0"); + auto r2_parameter = Parameter(&builder, 1, r2_shape, "input1"); auto r2_implicit_parameter2 = - builder.Parameter(2, r2_implicit_shape2, "input2"); + Parameter(&builder, 2, r2_implicit_shape2, "input2"); XlaOp op1 = BuildBinOp(spec.op1, r2_implicit_parameter1, r2_parameter, &builder); - XlaOp op2 = BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder); + BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder); Array2D expected_array(spec.output_bounds[0], spec.output_bounds[1]); @@ -544,9 +546,9 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances, XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { XlaBuilder b(TestName()); - auto r1 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}})); - auto r2 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}, {3, 4}})); - b.Add(r2, r1); + auto r1 = ConstantLiteral(&b, *Literal::CreateR2({{1, 2}})); + auto r2 = ConstantLiteral(&b, *Literal::CreateR2({{1, 2}, {3, 4}})); + Add(r2, r1); auto expected = Literal::CreateR2({{2, 4}, {4, 6}}); @@ -555,9 +557,9 @@ XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { XlaBuilder b(TestName()); - auto r1 = b.ConstantLiteral(*Literal::CreateR2({{1}, {2}})); - auto r2 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}, {3, 4}})); - b.Add(r2, r1); + auto r1 = ConstantLiteral(&b, *Literal::CreateR2({{1}, {2}})); + auto r2 = ConstantLiteral(&b, *Literal::CreateR2({{1, 2}, {3, 4}})); + Add(r2, r1); auto expected = Literal::CreateR2({{2, 3}, {5, 6}}); @@ -566,10 +568,10 @@ XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { XlaBuilder b(TestName()); - auto r1 = b.ConstantR1({10, 20}); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r3, r1, {0}); + auto r1 = ConstantR1(&b, {10, 20}); + auto r3 = ConstantLiteral( + &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r3, r1, {0}); auto expected = Literal::CreateR3({{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); @@ -579,10 +581,10 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { XlaBuilder b(TestName()); - auto r1 = b.ConstantR1({10, 20}); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r1, r3, {1}); + auto r1 = ConstantR1(&b, {10, 20}); + auto r3 = ConstantLiteral( + &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r1, r3, {1}); auto expected = Literal::CreateR3({{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); @@ -592,10 +594,10 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { XlaBuilder b(TestName()); - auto r1 = b.ConstantR1({10, 20}); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r1, r3, {2}); + auto r1 = ConstantR1(&b, {10, 20}); + auto r3 = ConstantLiteral( + &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r1, r3, {2}); auto expected = Literal::CreateR3({{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}}); @@ -605,17 +607,17 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { XlaBuilder b(TestName()); - auto r1_0 = b.ConstantR1({1000, 2000}); - auto r1_1 = b.ConstantR1({100, 200}); - auto r1_2 = b.ConstantR1({10, 20}); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + auto r1_0 = ConstantR1(&b, {1000, 2000}); + auto r1_1 = ConstantR1(&b, {100, 200}); + auto r1_2 = ConstantR1(&b, {10, 20}); + auto r3 = ConstantLiteral( + &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); for (int i = 0; i < 3; ++i) { - r3 = b.Add(r1_0, r3, {0}); - r3 = b.Add(r3, r1_1, {1}); - r3 = b.Add(r1_2, r3, {2}); + r3 = Add(r1_0, r3, {0}); + r3 = Add(r3, r1_1, {1}); + r3 = Add(r1_2, r3, {2}); } - r3 = b.Mul(r3, b.ConstantR0(-2)); + r3 = Mul(r3, ConstantR0(&b, -2)); auto expected = Literal::CreateR3( {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}}, @@ -626,17 +628,17 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { XlaBuilder b(TestName()); - auto r1_0 = b.ConstantR1({1000, 2000}); - auto r1_1 = b.ConstantR1({100, 200}); - auto r1_2 = b.ConstantR1({10, 20}); - auto r0 = b.ConstantR0(3); - auto r3 = b.Broadcast(r0, {2, 2, 2}); + auto r1_0 = ConstantR1(&b, {1000, 2000}); + auto r1_1 = ConstantR1(&b, {100, 200}); + auto r1_2 = ConstantR1(&b, {10, 20}); + auto r0 = ConstantR0(&b, 3); + auto r3 = Broadcast(r0, {2, 2, 2}); for (int i = 0; i < 3; ++i) { - r3 = b.Add(r1_0, r3, {0}); - r3 = b.Add(r3, r1_1, {1}); - r3 = b.Add(r1_2, r3, {2}); + r3 = Add(r1_0, r3, {0}); + r3 = Add(r3, r1_1, {1}); + r3 = Add(r1_2, r3, {2}); } - r3 = b.Mul(r3, b.ConstantR0(-1)); + r3 = Mul(r3, ConstantR0(&b, -1)); auto expected = Literal::CreateR3( {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}}, @@ -650,10 +652,10 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { // results in a shape incompatible with the lhs [2, 3, 1]. XlaBuilder b(TestName()); - b.Add(b.ConstantR2({{1.0, 5.0}, {1.0, 5.0}}), - b.ConstantLiteral(*Literal::CreateR3( - {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), - /*broadcast_dimensions=*/{1, 2}); + Add(ConstantR2(&b, {{1.0, 5.0}, {1.0, 5.0}}), + ConstantLiteral(&b, *Literal::CreateR3( + {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), + /*broadcast_dimensions=*/{1, 2}); auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); @@ -665,8 +667,8 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { // Test invalid broadcasting with [1, 2] and [2, 3] inputs. XlaBuilder b(TestName()); - b.Add(b.ConstantR2({{1.0, 2.0}}), - b.ConstantR2({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); + Add(ConstantR2(&b, {{1.0, 2.0}}), + ConstantR2(&b, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); @@ -678,8 +680,8 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { // Test invalid broadcasting with [1, 2] and [2, 3] inputs. XlaBuilder b(TestName()); - b.Add(b.ConstantR2({{1.0, 2.0}}), - b.ConstantR2({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); + Add(ConstantR2(&b, {{1.0, 2.0}}), + ConstantR2(&b, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index 5fd33b50c94356839bbed58acd43b7d0286f4a7e..bc64a19ce22072152216a7c150fbd16480d261fb 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -34,7 +34,7 @@ class CallOpTest : public ClientLibraryTestBase { protected: XlaComputation CreateR0F32IdentityComputation() { XlaBuilder builder("Identity"); - builder.Parameter(0, r0f32_, "x"); + Parameter(&builder, 0, r0f32_, "x"); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -42,9 +42,9 @@ class CallOpTest : public ClientLibraryTestBase { XlaComputation CreateR1S0F32AdditionComputation() { XlaBuilder builder("Addition"); - auto x = builder.Parameter(0, r1s0f32_, "x"); - auto y = builder.Parameter(1, r1s0f32_, "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, r1s0f32_, "x"); + auto y = Parameter(&builder, 1, r1s0f32_, "y"); + Add(x, y); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -52,9 +52,9 @@ class CallOpTest : public ClientLibraryTestBase { XlaComputation CreateR1S2F32AdditionComputation() { XlaBuilder builder("Addition"); - auto x = builder.Parameter(0, r1s2f32_, "x"); - auto y = builder.Parameter(1, r1s2f32_, "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, r1s2f32_, "x"); + auto y = Parameter(&builder, 1, r1s2f32_, "y"); + Add(x, y); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -62,7 +62,7 @@ class CallOpTest : public ClientLibraryTestBase { XlaComputation CreateR0F32TupleComputation() { XlaBuilder builder("Tuple"); - builder.Tuple({builder.Parameter(0, r0f32_, "x")}); + Tuple(&builder, {Parameter(&builder, 0, r0f32_, "x")}); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -76,8 +76,8 @@ class CallOpTest : public ClientLibraryTestBase { XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR0F32IdentityComputation(); - auto constant = builder.ConstantLiteral(*Literal::CreateR0(42.0)); - builder.Call(callee, {constant}); + auto constant = ConstantLiteral(&builder, *Literal::CreateR0(42.0)); + Call(&builder, callee, {constant}); ComputeAndCompareR0(&builder, 42.0, {}, ErrorSpec(0.01f)); } @@ -85,9 +85,9 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR1S0F32AdditionComputation(); - auto x = builder.ConstantLiteral(*Literal::CreateR1({})); - auto y = builder.ConstantLiteral(*Literal::CreateR1({})); - builder.Call(callee, {x, y}); + auto x = ConstantLiteral(&builder, *Literal::CreateR1({})); + auto y = ConstantLiteral(&builder, *Literal::CreateR1({})); + Call(&builder, callee, {x, y}); ComputeAndCompareR1(&builder, {}, {}, ErrorSpec(0.01f)); } @@ -95,9 +95,9 @@ XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR1S2F32AdditionComputation(); - auto x = builder.ConstantLiteral(*Literal::CreateR1({1.0f, 2.0f})); - auto y = builder.ConstantLiteral(*Literal::CreateR1({2.0f, 3.0f})); - builder.Call(callee, {x, y}); + auto x = ConstantLiteral(&builder, *Literal::CreateR1({1.0f, 2.0f})); + auto y = ConstantLiteral(&builder, *Literal::CreateR1({2.0f, 3.0f})); + Call(&builder, callee, {x, y}); ComputeAndCompareR1(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f)); } @@ -105,26 +105,26 @@ XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) { XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) { XlaBuilder builder("inner"); { - auto x = builder.Parameter(0, r0f32_, "x"); - builder.Add(x, builder.ConstantR0(1.0)); + auto x = Parameter(&builder, 0, r0f32_, "x"); + Add(x, ConstantR0(&builder, 1.0)); } TF_ASSERT_OK_AND_ASSIGN(XlaComputation inner, builder.Build()); XlaBuilder builder2("outer"); { - auto x = builder2.Parameter(0, r0f32_, "x"); - x = builder2.Call(inner, {x}); - x = builder2.Call(inner, {x}); - x = builder2.Call(inner, {x}); + auto x = Parameter(&builder2, 0, r0f32_, "x"); + x = Call(&builder2, inner, {x}); + x = Call(&builder2, inner, {x}); + x = Call(&builder2, inner, {x}); } TF_ASSERT_OK_AND_ASSIGN(XlaComputation outer, builder2.Build()); XlaBuilder builder3("outermost"); { - auto x = builder3.Parameter(0, r0f32_, "x"); - x = builder3.Call(outer, {x}); - x = builder3.Call(outer, {x}); - x = builder3.Call(outer, {x}); + auto x = Parameter(&builder3, 0, r0f32_, "x"); + x = Call(&builder3, outer, {x}); + x = Call(&builder3, outer, {x}); + x = Call(&builder3, outer, {x}); } TF_ASSERT_OK_AND_ASSIGN( @@ -138,7 +138,7 @@ XLA_TEST_F(CallOpTest, CallR0F32Tuple) { XlaComputation callee = CreateR0F32TupleComputation(); auto elem = Literal::CreateR0(42.0); auto tuple = Literal::MakeTuple({elem.get()}); - builder.Call(callee, {builder.ConstantLiteral(*elem)}); + Call(&builder, callee, {ConstantLiteral(&builder, *elem)}); ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f)); } diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index 660ff0cad5666219a4a7cb1eedbed03f06e651ba..1ad57c075b22c7730ffd8d1beeab60c9d5dc7458 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -38,9 +38,9 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { XlaBuilder builder("add_two_params"); auto param_literal = Literal::CreateR1({1.1f, 2.2f}); - auto p0 = builder.Parameter(0, param_literal->shape(), "param0"); - auto p1 = builder.Parameter(1, param_literal->shape(), "param1"); - auto add = builder.Add(p0, p1); + auto p0 = Parameter(&builder, 0, param_literal->shape(), "param0"); + auto p1 = Parameter(&builder, 1, param_literal->shape(), "param1"); + Add(p0, p1); auto param0_data = client_->TransferToServer(*param_literal).ConsumeValueOrDie(); @@ -77,9 +77,9 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { XlaBuilder builder("add_two_params"); - auto p0 = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); - auto p1 = builder.Parameter(1, ShapeUtil::MakeShape(F32, {4}), "param1"); - auto add = builder.Mul(p0, p1); + auto p0 = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param0"); + auto p1 = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {4}), "param1"); + Mul(p0, p1); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index bf8ed4d9fb0bc61b86ef0b5872711a122a3d416b..dafd6ebabbe6edafc1c926677b3ea00e775be010 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -486,11 +487,11 @@ ClientLibraryTestBase::ComputeValueAndReference( XlaComputation ClientLibraryTestBase::CreateScalarRelu() { XlaBuilder builder("relu"); auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); - auto z_value = builder.Parameter(0, shape, "z_value"); + auto z_value = Parameter(&builder, 0, shape, "z_value"); auto zero = use_bfloat16_ - ? builder.ConstantR0(static_cast(0.0f)) - : builder.ConstantR0(0.0f); - builder.Max(z_value, zero); + ? ConstantR0(&builder, static_cast(0.0f)) + : ConstantR0(&builder, 0.0f); + Max(z_value, zero); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -499,9 +500,9 @@ XlaComputation ClientLibraryTestBase::CreateScalarRelu() { XlaComputation ClientLibraryTestBase::CreateScalarMax() { XlaBuilder builder("max"); auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); - auto x = builder.Parameter(0, shape, "x"); - auto y = builder.Parameter(1, shape, "y"); - builder.Max(x, y); + auto x = Parameter(&builder, 0, shape, "x"); + auto y = Parameter(&builder, 1, shape, "y"); + Max(x, y); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -510,13 +511,13 @@ XlaComputation ClientLibraryTestBase::CreateScalarMax() { XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() { XlaBuilder builder("relu_sensitivity"); auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); - auto activation = builder.Parameter(0, shape, "activation"); - auto backprop = builder.Parameter(1, shape, "backprop"); + auto activation = Parameter(&builder, 0, shape, "activation"); + auto backprop = Parameter(&builder, 1, shape, "backprop"); auto zero = use_bfloat16_ - ? builder.ConstantR0(static_cast(0.0f)) - : builder.ConstantR0(0.0f); - auto activation_gtz = builder.Gt(activation, zero); - builder.Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero); + ? ConstantR0(&builder, static_cast(0.0f)) + : ConstantR0(&builder, 0.0f); + auto activation_gtz = Gt(activation, zero); + Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); @@ -559,8 +560,8 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder) { - return builder->ConstantLiteral( - use_bfloat16_ ? *Literal::ConvertF32ToBF16(literal) : literal); + return ConstantLiteral( + builder, use_bfloat16_ ? *Literal::ConvertF32ToBF16(literal) : literal); } std::unique_ptr @@ -588,7 +589,7 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral( client_->TransferToServer(*param_literal, device_handle) .ConsumeValueOrDie(); *data_handle = - builder->Parameter(parameter_number, param_literal->shape(), name); + Parameter(builder, parameter_number, param_literal->shape(), name); return data; } diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 0499fec5898a42affa0e0a712dee10187355c13e..5361ae6783c4c103cf923ffbda066165545c39a1 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -373,6 +373,13 @@ class ClientLibraryTestBase : public ::testing::Test { // The float type used in this test, BF16 or F32 according to use_bfloat16. PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; } + // Executes the computation and calculates the expected reference value using + // the reference client. Returns two literals in the order of (expected, + // actual). + StatusOr, std::unique_ptr>> + ComputeValueAndReference(XlaBuilder* builder, + tensorflow::gtl::ArraySlice arguments); + Client* client_; Client* ref_client_; // To compute reference result. ExecutionOptions execution_options_; @@ -390,13 +397,6 @@ class ClientLibraryTestBase : public ::testing::Test { const string& error_message)>& verify_output, const Shape* output_with_layout = nullptr); - // Executes the computation and calculates the expected reference value using - // the reference client. Returns two literals in the order of (expected, - // actual). - StatusOr, std::unique_ptr>> - ComputeValueAndReference(XlaBuilder* builder, - tensorflow::gtl::ArraySlice arguments); - // Whether to run tests with all float-type input/output converted to // bfloat16. bool use_bfloat16_ = false; @@ -545,7 +545,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = builder->Parameter(parameter_number, literal->shape(), name); + *data_handle = Parameter(builder, parameter_number, literal->shape(), name); return data; } @@ -559,7 +559,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = builder->Parameter(parameter_number, literal->shape(), name); + *data_handle = Parameter(builder, parameter_number, literal->shape(), name); return data; } @@ -573,7 +573,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = builder->Parameter(parameter_number, literal->shape(), name); + *data_handle = Parameter(builder, parameter_number, literal->shape(), name); return data; } @@ -587,7 +587,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = builder->Parameter(parameter_number, literal->shape(), name); + *data_handle = Parameter(builder, parameter_number, literal->shape(), name); return data; } diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 08671cf62445826649b5c97003f998ae98a59d97..831b863998f1cab31d37aa4474be45d8531075ac 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -43,8 +43,8 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { std::vector> layouts = {{0, 1}, {1, 0}}; for (const std::vector& execute_layout : layouts) { for (const std::vector& transfer_layout : layouts) { - b.Add(b.ConstantR2({{1, 2}, {3, 4}}), - b.ConstantR2({{10, 20}, {30, 40}})); + Add(ConstantR2(&b, {{1, 2}, {3, 4}}), + ConstantR2(&b, {{10, 20}, {30, 40}})); TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); ExecutionOptions execution_options = execution_options_; @@ -72,8 +72,8 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { XlaBuilder b(TestName()); - b.Tuple({b.ConstantR2({{1, 2}, {3, 4}}), - b.ConstantR2({{10, 20}, {30, 40}})}); + Tuple(&b, {ConstantR2(&b, {{1, 2}, {3, 4}}), + ConstantR2(&b, {{10, 20}, {30, 40}})}); TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); @@ -117,8 +117,8 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { client_->TransferToServer(*Literal::CreateR2({{5, 6}, {7, 8}}))); XlaBuilder b(TestName() + ".add"); - b.Add(b.Parameter(0, shape, "param_0"), - b.ConstantR2({{1, 2}, {3, 4}})); + Add(Parameter(&b, 0, shape, "param_0"), + ConstantR2(&b, {{1, 2}, {3, 4}})); TF_ASSERT_OK_AND_ASSIGN(add_with_one_arg, b.Build()); // We can't really test parallel execution on CPU since all of the cores in a diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index 50a006964869b3e5dce431d441f7cd81af9df910..eb211dd8ff376fb0da03b3e68be1d849970d96fd 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -77,7 +77,7 @@ class CompilationCacheTest : public ClientLibraryTestBase { // TODO(b/74197823): Disabled because there is no cache in the new design. XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) { XlaBuilder builder(TestName()); - builder.Neg(builder.ConstantR0(42.0)); + Neg(ConstantR0(&builder, 42.0)); XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false); @@ -99,7 +99,7 @@ XLA_TEST_F(CompilationCacheTest, .ConsumeValueOrDie(); XlaBuilder builder(TestName()); - builder.Neg(builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param")); + Neg(Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param")); XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecuteComputationR0F32(computation, {data_42.get()}, -42.0, @@ -115,16 +115,16 @@ XLA_TEST_F(CompilationCacheTest, // TODO(b/74197823): Disabled because there is no cache in the new design. XLA_TEST_F(CompilationCacheTest, DISABLED_MultipleComputations) { XlaBuilder builder_neg(TestName() + "_neg"); - builder_neg.Neg(builder_neg.ConstantR0(42.0)); + Neg(ConstantR0(&builder_neg, 42.0)); XlaComputation computation_neg = builder_neg.Build().ConsumeValueOrDie(); XlaBuilder builder_exp(TestName() + "_exp"); - builder_exp.Exp(builder_exp.ConstantR0(1.0)); + Exp(ConstantR0(&builder_exp, 1.0)); XlaComputation computation_exp = builder_exp.Build().ConsumeValueOrDie(); XlaBuilder builder_add(TestName() + "_add"); - builder_add.Add(builder_add.ConstantR0(2.0), - builder_add.ConstantR0(3.0)); + Add(ConstantR0(&builder_add, 2.0), + ConstantR0(&builder_add, 3.0)); XlaComputation computation_add = builder_add.Build().ConsumeValueOrDie(); ExecuteComputationR0F32(computation_neg, {}, -42.0, @@ -154,7 +154,7 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) { client_->TransferToServer(*colmaj_array).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"); XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecuteComputationR2F32(computation, {colmaj_handle.get()}, diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index ba22530f1cfee56337f862c25122d399dbf0f1e4..1a396b090c615dbd829964bd68ebda74df29c71e 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -99,7 +99,7 @@ TEST_F(ComputeConstantTest, ScalarInt32Literal) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto computation = b.ConstantR0(42); + auto computation = ConstantR0(&b, 42); EXPECT_TRUE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); @@ -113,7 +113,7 @@ TEST_F(ComputeConstantTest, ScalarFloatAdd) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); auto computation = - b.Add(b.ConstantR0(42.5f), b.ConstantR0(1.5f)); + Add(ConstantR0(&b, 42.5f), ConstantR0(&b, 1.5f)); EXPECT_TRUE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); @@ -127,8 +127,8 @@ TEST_F(ComputeConstantTest, ScalarRng) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); auto computation = - b.RngUniform(b.ConstantR0(1.1f), b.ConstantR0(2.1f), - ShapeUtil::MakeShape(F32, {})); + RngUniform(ConstantR0(&b, 1.1f), ConstantR0(&b, 2.1f), + ShapeUtil::MakeShape(F32, {})); EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); @@ -141,7 +141,7 @@ TEST_F(ComputeConstantTest, DirectParamMissing) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"); + auto computation = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "param"); EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); @@ -156,8 +156,8 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); auto computation = - b.Add(b.ConstantR0(1.0f), - b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param")); + Add(ConstantR0(&b, 1.0f), + Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "param")); EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); @@ -174,18 +174,18 @@ TEST_F(ComputeConstantTest, UnrelatedParam) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0"); + auto param_a = Parameter(&b, 10, ShapeUtil::MakeShape(F32, {}), "param0"); auto constant_4 = - b.Add(b.ConstantR0(2.5f), b.ConstantR0(1.5f)); - auto not_constant_a = b.Add(constant_4, param_a); + Add(ConstantR0(&b, 2.5f), ConstantR0(&b, 1.5f)); + auto not_constant_a = Add(constant_4, param_a); - auto param_b = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "param1"); + auto param_b = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "param1"); auto constant_9 = - b.Mul(b.ConstantR0(2.0f), b.ConstantR0(4.5f)); - auto not_constant_b = b.Add(param_b, constant_9); + Mul(ConstantR0(&b, 2.0f), ConstantR0(&b, 4.5f)); + auto not_constant_b = Add(param_b, constant_9); - auto constant_13 = b.Add(constant_4, constant_9); - b.Add(not_constant_b, b.Add(constant_13, not_constant_a)); + auto constant_13 = Add(constant_4, constant_9); + Add(not_constant_b, Add(constant_13, not_constant_a)); EXPECT_TRUE(IsConstant(constant_13, &b)); @@ -201,7 +201,7 @@ TEST_F(ComputeConstantTest, NonScalarAdd) { XlaBuilder b(TestName()); auto computation = - b.Add(b.ConstantR1({1, 2}), b.ConstantR1({3, 4})); + Add(ConstantR1(&b, {1, 2}), ConstantR1(&b, {3, 4})); EXPECT_TRUE(IsConstant(computation, &b)); TF_ASSERT_OK_AND_ASSIGN(auto computed, @@ -216,7 +216,7 @@ TEST_F(ComputeConstantTest, IntegerDivide) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto computation = b.Div(b.ConstantR0(15), b.ConstantR0(3)); + auto computation = Div(ConstantR0(&b, 15), ConstantR0(&b, 3)); EXPECT_TRUE(IsConstant(computation, &b)); TF_ASSERT_OK_AND_ASSIGN(auto computed, @@ -237,8 +237,8 @@ XLA_TEST_F(ComputeConstantTest, Layout) { TF_ASSERT_OK_AND_ASSIGN( auto computed, ComputeConstantLiteral( client, - b.Add(b.ConstantR2({{1, 2}, {3, 4}}), - b.ConstantR2({{10, 20}, {30, 40}})), + Add(ConstantR2(&b, {{1, 2}, {3, 4}}), + ConstantR2(&b, {{10, 20}, {30, 40}})), &b, &layout_proto)); std::unique_ptr expected_literal = diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index 352864502a184237fde600330836fe471a5444f2..1161b560b7b0756556911812666c6f4fe9179f72 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -39,7 +39,7 @@ using ::testing::HasSubstr; // Concatenate expects at least one argument. XLA_TEST_F(ConcatTest, Concat_Nothing) { XlaBuilder builder(TestName()); - builder.ConcatInDim({}, 0); + ConcatInDim(&builder, {}, 0); StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), @@ -49,8 +49,8 @@ XLA_TEST_F(ConcatTest, Concat_Nothing) { // Concatenate with one argument works. XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0, 64.0}); - builder.ConcatInDim({a}, 0); + auto a = ConstantR1(&builder, {42.0, 64.0}); + ConcatInDim(&builder, {a}, 0); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -58,8 +58,8 @@ XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) { XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.ConcatInDim({a}, 0); + auto a = ConstantR1(&builder, {}); + ConcatInDim(&builder, {a}, 0); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -69,9 +69,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) { // to concatenate on. XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR0(42.0); - auto b = builder.ConstantR0(64.0); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR0(&builder, 42.0); + auto b = ConstantR0(&builder, 64.0); + ConcatInDim(&builder, {a, b}, 0); StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), @@ -80,9 +80,9 @@ XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) { XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + ConcatInDim(&builder, {a, b}, 0); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -90,9 +90,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) { XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({256.0}); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {256.0}); + ConcatInDim(&builder, {a, b}, 0); std::vector expected = {256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -100,9 +100,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) { XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0, 64.0}); - auto b = builder.ConstantR1({}); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR1(&builder, {42.0, 64.0}); + auto b = ConstantR1(&builder, {}); + ConcatInDim(&builder, {a, b}, 0); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -110,9 +110,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) { XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0, 64.0}); - auto b = builder.ConstantR1({256.0}); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR1(&builder, {42.0, 64.0}); + auto b = ConstantR1(&builder, {256.0}); + ConcatInDim(&builder, {a, b}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -130,9 +130,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) { } XlaBuilder builder(TestName()); - auto a = builder.ConstantR1(lhs); - auto b = builder.ConstantR1(rhs); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR1(&builder, lhs); + auto b = ConstantR1(&builder, rhs); + ConcatInDim(&builder, {a, b}, 0); ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } @@ -140,9 +140,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) { XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) { for (int dim : {0, 1}) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 0)); - auto b = builder.ConstantR2FromArray2D(Array2D(0, 0)); - builder.ConcatInDim({a, b}, dim); + auto a = ConstantR2FromArray2D(&builder, Array2D(0, 0)); + auto b = ConstantR2FromArray2D(&builder, Array2D(0, 0)); + ConcatInDim(&builder, {a, b}, dim); ComputeAndCompareR2(&builder, Array2D(0, 0), {}, ErrorSpec(0.0001)); @@ -153,9 +153,9 @@ XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim0) { XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(1, 1); auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0); - auto a = builder.ConstantR2FromArray2D(*a_array); - auto b = builder.ConstantR2FromArray2D(*b_array); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR2FromArray2D(&builder, *a_array); + auto b = ConstantR2FromArray2D(&builder, *b_array); + ConcatInDim(&builder, {a, b}, 0); Array2D expected({ {0}, @@ -168,9 +168,9 @@ XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) { XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(1, 1); auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0); - auto a = builder.ConstantR2FromArray2D(*a_array); - auto b = builder.ConstantR2FromArray2D(*b_array); - builder.ConcatInDim({a, b}, 1); + auto a = ConstantR2FromArray2D(&builder, *a_array); + auto b = ConstantR2FromArray2D(&builder, *b_array); + ConcatInDim(&builder, {a, b}, 1); Array2D expected({ {0, 64}, @@ -181,9 +181,9 @@ XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) { XLA_TEST_F(ConcatTest, Concat2x0With2x5) { XlaBuilder builder(TestName()); auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0); - auto a = builder.ConstantR2FromArray2D(Array2D(2, 0)); - auto b = builder.ConstantR2FromArray2D(*b_array); - builder.ConcatInDim({a, b}, 1); + auto a = ConstantR2FromArray2D(&builder, Array2D(2, 0)); + auto b = ConstantR2FromArray2D(&builder, *b_array); + ConcatInDim(&builder, {a, b}, 1); ComputeAndCompareR2(&builder, *b_array, {}, ErrorSpec(0.0001)); } @@ -192,9 +192,9 @@ XLA_TEST_F(ConcatTest, Concat2x3With2x5) { XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(2, 3); auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0); - auto a = builder.ConstantR2FromArray2D(*a_array); - auto b = builder.ConstantR2FromArray2D(*b_array); - builder.ConcatInDim({a, b}, 1); + auto a = ConstantR2FromArray2D(&builder, *a_array); + auto b = ConstantR2FromArray2D(&builder, *b_array); + ConcatInDim(&builder, {a, b}, 1); Array2D expected({ {0, 1, 2, 64, 65, 66, 67, 68}, @@ -206,9 +206,9 @@ XLA_TEST_F(ConcatTest, Concat2x3With2x5) { XLA_TEST_F(ConcatTest, Concat3x2With0x2) { XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(3, 2); - auto a = builder.ConstantR2FromArray2D(*a_array); - auto b = builder.ConstantR2FromArray2D(Array2D(0, 2)); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR2FromArray2D(&builder, *a_array); + auto b = ConstantR2FromArray2D(&builder, Array2D(0, 2)); + ConcatInDim(&builder, {a, b}, 0); ComputeAndCompareR2(&builder, *a_array, {}, ErrorSpec(0.0001)); } @@ -217,9 +217,9 @@ XLA_TEST_F(ConcatTest, Concat3x2With5x2) { XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(3, 2); auto b_array = CreatePatternedMatrix(5, 2, /*offset=*/64.0); - auto a = builder.ConstantR2FromArray2D(*a_array); - auto b = builder.ConstantR2FromArray2D(*b_array); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR2FromArray2D(&builder, *a_array); + auto b = ConstantR2FromArray2D(&builder, *b_array); + ConcatInDim(&builder, {a, b}, 0); Array2D expected({ {0, 1}, @@ -236,9 +236,9 @@ XLA_TEST_F(ConcatTest, Concat3x2With5x2) { XLA_TEST_F(ConcatTest, Concat_R3_3x0x2_3x0x1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR3FromArray3D(Array3D(3, 0, 2)); - auto b = builder.ConstantR3FromArray3D(Array3D(3, 0, 1)); - builder.ConcatInDim({a, b}, 2); + auto a = ConstantR3FromArray3D(&builder, Array3D(3, 0, 2)); + auto b = ConstantR3FromArray3D(&builder, Array3D(3, 0, 1)); + ConcatInDim(&builder, {a, b}, 2); ComputeAndCompareR3(&builder, Array3D(3, 0, 3), {}, ErrorSpec(0.0001)); } @@ -257,9 +257,9 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) { {{7}}, {{8}}, }); - auto a = builder.ConstantR3FromArray3D(a_array); - auto b = builder.ConstantR3FromArray3D(b_array); - builder.ConcatInDim({a, b}, 2); + auto a = ConstantR3FromArray3D(&builder, a_array); + auto b = ConstantR3FromArray3D(&builder, b_array); + ConcatInDim(&builder, {a, b}, 2); Array3D expected({ {{0, 1, 6}}, @@ -271,10 +271,10 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) { XLA_TEST_F(ConcatTest, Concat_R1_1x1_1x1_1x1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0}); - auto b = builder.ConstantR1({64.0}); - auto c = builder.ConstantR1({256.0}); - builder.ConcatInDim({a, b, c}, 0); + auto a = ConstantR1(&builder, {42.0}); + auto b = ConstantR1(&builder, {64.0}); + auto c = ConstantR1(&builder, {256.0}); + ConcatInDim(&builder, {a, b, c}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -300,10 +300,10 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) { {{7}}, {{11}}, }); - auto a = builder.ConstantR3FromArray3D(a_array); - auto b = builder.ConstantR3FromArray3D(b_array); - auto c = builder.ConstantR3FromArray3D(c_array); - builder.ConcatInDim({a, b, c}, 2); + auto a = ConstantR3FromArray3D(&builder, a_array); + auto b = ConstantR3FromArray3D(&builder, b_array); + auto c = ConstantR3FromArray3D(&builder, c_array); + ConcatInDim(&builder, {a, b, c}, 2); Array3D expected({ {{0, 1, 2, 3}}, @@ -315,11 +315,11 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) { XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0}); - auto b = builder.ConstantR1({64.0}); - auto c = builder.ConstantR1({256.0}); + auto a = ConstantR1(&builder, {42.0}); + auto b = ConstantR1(&builder, {64.0}); + auto c = ConstantR1(&builder, {256.0}); // concatenated = (a concat b) concat c - builder.ConcatInDim({builder.ConcatInDim({a, b}, 0), c}, 0); + ConcatInDim(&builder, {ConcatInDim(&builder, {a, b}, 0), c}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -327,11 +327,11 @@ XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) { XLA_TEST_F(ConcatTest, DoubleConcatRightAssociative) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0}); - auto b = builder.ConstantR1({64.0}); - auto c = builder.ConstantR1({256.0}); + auto a = ConstantR1(&builder, {42.0}); + auto b = ConstantR1(&builder, {64.0}); + auto c = ConstantR1(&builder, {256.0}); // concatenated = a concat (b concat c) - builder.ConcatInDim({a, builder.ConcatInDim({b, c}, 0)}, 0); + ConcatInDim(&builder, {a, ConcatInDim(&builder, {b, c}, 0)}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -346,9 +346,9 @@ XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim0) { } XlaBuilder builder(TestName()); - auto a = builder.ConstantR2FromArray2D(lhs); - auto b = builder.ConstantR2FromArray2D(rhs); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR2FromArray2D(&builder, lhs); + auto b = ConstantR2FromArray2D(&builder, rhs); + ConcatInDim(&builder, {a, b}, 0); Array2D expected(2, 1024); for (int i = 0; i < 1024; ++i) { @@ -367,9 +367,9 @@ XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim1) { } XlaBuilder builder(TestName()); - auto a = builder.ConstantR2FromArray2D(lhs); - auto b = builder.ConstantR2FromArray2D(rhs); - builder.ConcatInDim({a, b}, 1); + auto a = ConstantR2FromArray2D(&builder, lhs); + auto b = ConstantR2FromArray2D(&builder, rhs); + ConcatInDim(&builder, {a, b}, 1); Array2D expected(1, 2048); for (int i = 0; i < 1024; ++i) { @@ -392,9 +392,9 @@ XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) { } XlaBuilder builder(TestName()); - auto a = builder.ConstantR2FromArray2D(lhs); - auto b = builder.ConstantR2FromArray2D(rhs); - builder.ConcatInDim({a, b}, 1); + auto a = ConstantR2FromArray2D(&builder, lhs); + auto b = ConstantR2FromArray2D(&builder, rhs); + ConcatInDim(&builder, {a, b}, 1); Array2D expected(64, 66); for (int i0 = 0; i0 < 64; ++i0) { @@ -410,9 +410,9 @@ XLA_TEST_F(ConcatTest, CannotConcatOpaques) { XlaBuilder builder(TestName()); auto opaque_shape = ShapeUtil::MakeOpaqueShape(); auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1}); - auto x = builder.Parameter(0, r1f32, "x"); - auto y = builder.Parameter(1, opaque_shape, "y"); - builder.ConcatInDim({x, y}, 0); + auto x = Parameter(&builder, 0, r1f32, "x"); + auto y = Parameter(&builder, 1, opaque_shape, "y"); + ConcatInDim(&builder, {x, y}, 0); StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT( @@ -425,9 +425,9 @@ XLA_TEST_F(ConcatTest, CannotConcatTokens) { XlaBuilder builder(TestName()); auto token_shape = ShapeUtil::MakeTokenShape(); auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1}); - auto x = builder.Parameter(0, r1f32, "x"); - auto y = builder.Parameter(1, token_shape, "y"); - builder.ConcatInDim({x, y}, 0); + auto x = Parameter(&builder, 0, r1f32, "x"); + auto y = Parameter(&builder, 1, token_shape, "y"); + ConcatInDim(&builder, {x, y}, 0); StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT( @@ -437,10 +437,10 @@ XLA_TEST_F(ConcatTest, CannotConcatTokens) { XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) { XlaBuilder builder(TestName()); - auto p0 = builder.ConstantR1({true}); - auto p1 = builder.ConstantR1({false}); - auto p2 = builder.ConstantR1({true}); - builder.ConcatInDim({p0, p1, p2}, 0); + auto p0 = ConstantR1(&builder, {true}); + auto p1 = ConstantR1(&builder, {false}); + auto p2 = ConstantR1(&builder, {true}); + ConcatInDim(&builder, {p0, p1, p2}, 0); bool expected[] = {true, false, true}; ComputeAndCompareR1(&builder, expected, {}); @@ -448,11 +448,11 @@ XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) { XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) { XlaBuilder builder(TestName()); - auto a0 = builder.ConstantR1({1}); - auto a1 = builder.ConstantR1({2, 3}); - auto a2 = builder.ConstantR1({4, 5, 6}); - auto a3 = builder.ConstantR1({7, 8, 9, 10}); - builder.ConcatInDim({a0, a1, a2, a3}, 0); + auto a0 = ConstantR1(&builder, {1}); + auto a1 = ConstantR1(&builder, {2, 3}); + auto a2 = ConstantR1(&builder, {4, 5, 6}); + auto a3 = ConstantR1(&builder, {7, 8, 9, 10}); + ConcatInDim(&builder, {a0, a1, a2, a3}, 0); std::vector expected(10); std::iota(expected.begin(), expected.end(), 1); @@ -487,7 +487,7 @@ XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) { auto p1 = CreateR3Parameter(arr1, /*parameter_number=*/1, "p1", &builder, &h1); - builder.ConcatInDim({h0, h1}, 2); + ConcatInDim(&builder, {h0, h1}, 2); ComputeAndCompareR3(&builder, expected, {p0.get(), p1.get()}); } @@ -514,9 +514,9 @@ TEST_P(ConcatR2BinaryTest, DoIt) { rhs.FillUnique(1000); XlaBuilder builder(TestName()); - auto a0 = builder.ConstantR2FromArray2D(lhs); - auto a1 = builder.ConstantR2FromArray2D(rhs); - builder.ConcatInDim({a0, a1}, spec.concat_dimension); + auto a0 = ConstantR2FromArray2D(&builder, lhs); + auto a1 = ConstantR2FromArray2D(&builder, rhs); + ConcatInDim(&builder, {a0, a1}, spec.concat_dimension); std::unique_ptr> expected = ReferenceUtil::Concat2D(lhs, rhs, spec.concat_dimension); @@ -540,13 +540,13 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, f32_scalar, "x"); - auto y = builder.Parameter(1, f32_scalar, "y"); - auto mul = builder.Mul(x, y); - auto add1 = builder.Add(mul, builder.ConstantR1({1.f, 2.f})); - auto add2 = builder.Add(mul, builder.ConstantR1({3.f, 4.f})); - auto add3 = builder.Add(mul, builder.ConstantR1({5.f, 6.f})); - builder.ConcatInDim({add1, add2, add3}, /*dimension=*/0); + auto x = Parameter(&builder, 0, f32_scalar, "x"); + auto y = Parameter(&builder, 1, f32_scalar, "y"); + auto mul = Mul(x, y); + auto add1 = Add(mul, ConstantR1(&builder, {1.f, 2.f})); + auto add2 = Add(mul, ConstantR1(&builder, {3.f, 4.f})); + auto add3 = Add(mul, ConstantR1(&builder, {5.f, 6.f})); + ConcatInDim(&builder, {add1, add2, add3}, /*dimension=*/0); ComputeAndCompareR1(&builder, {7., 8., 9., 10., 11., 12.}, {x_data.get(), y_data.get()}, ErrorSpec(1e-4)); @@ -564,13 +564,13 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, x_literal->shape(), "x"); - auto y = builder.Parameter(1, f32_scalar, "y"); - auto z = builder.Parameter(2, f32_scalar, "z"); - auto bcast = builder.Broadcast(y, {5}); - auto bcast2 = builder.Broadcast(z, {3}); - auto concat = builder.ConcatInDim({bcast, x}, /*dimension=*/0); - builder.ConcatInDim({concat, bcast2}, /*dimension=*/0); + auto x = Parameter(&builder, 0, x_literal->shape(), "x"); + auto y = Parameter(&builder, 1, f32_scalar, "y"); + auto z = Parameter(&builder, 2, f32_scalar, "z"); + auto bcast = Broadcast(y, {5}); + auto bcast2 = Broadcast(z, {3}); + auto concat = ConcatInDim(&builder, {bcast, x}, /*dimension=*/0); + ConcatInDim(&builder, {concat, bcast2}, /*dimension=*/0); ComputeAndCompareR1( &builder, @@ -592,13 +592,13 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) { auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, x_literal->shape(), "x"); - auto y = builder.Parameter(1, f32_scalar, "y"); - auto z = builder.Parameter(2, f32_scalar, "y"); - auto y_bcast = builder.Broadcast(y, {1, 5, 7}); - auto z_bcast = builder.Broadcast(z, {4, 1, 7}); - auto concat = builder.ConcatInDim({y_bcast, x}, /*dimension=*/0); - builder.ConcatInDim({concat, z_bcast}, /*dimension=*/1); + auto x = Parameter(&builder, 0, x_literal->shape(), "x"); + auto y = Parameter(&builder, 1, f32_scalar, "y"); + auto z = Parameter(&builder, 2, f32_scalar, "y"); + auto y_bcast = Broadcast(y, {1, 5, 7}); + auto z_bcast = Broadcast(z, {4, 1, 7}); + auto concat = ConcatInDim(&builder, {y_bcast, x}, /*dimension=*/0); + ConcatInDim(&builder, {concat, z_bcast}, /*dimension=*/1); Array3D y_bcast3d(1, 5, 7, 1.5f); Array3D z_bcast3d(4, 1, 7, 5.5f); auto concat0 = ReferenceUtil::Concat3D(y_bcast3d, x3d, 0); diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index 7ff6706935740c7d76ee5cd03eae292386760397..ee3c83039bfc13f6ad78111d92ba0f8387a3ade3 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -26,8 +26,8 @@ class ConditionalOpTest : public ClientLibraryTestBase { protected: XlaComputation CreateR0ConstantComputation(float value) { XlaBuilder builder("Constant"); - builder.Parameter(0, empty_tuple_, "tuple"); - builder.ConstantR0(value); + Parameter(&builder, 0, empty_tuple_, "tuple"); + ConstantR0(&builder, value); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -35,7 +35,7 @@ class ConditionalOpTest : public ClientLibraryTestBase { XlaComputation CreateR0IdentityComputation() { XlaBuilder builder("Identity"); - builder.Parameter(0, r0f32_, "x"); + Parameter(&builder, 0, r0f32_, "x"); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -43,8 +43,8 @@ class ConditionalOpTest : public ClientLibraryTestBase { XlaComputation CreateCeilComputation(const Shape& shape) { XlaBuilder builder("Ceil"); - auto param = builder.Parameter(0, shape, "param"); - builder.Ceil(param); + auto param = Parameter(&builder, 0, shape, "param"); + Ceil(param); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -60,8 +60,8 @@ class ConditionalOpTest : public ClientLibraryTestBase { XlaComputation CreateFloorComputation(const Shape& shape) { XlaBuilder builder("Floor"); - auto param = builder.Parameter(0, shape, "param"); - builder.Floor(param); + auto param = Parameter(&builder, 0, shape, "param"); + Floor(param); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -78,12 +78,12 @@ class ConditionalOpTest : public ClientLibraryTestBase { XlaComputation CreateTupleCeilComputation(const string& computation_name, const Shape& tuple_shape) { XlaBuilder builder(computation_name); - auto tuple = builder.Parameter(0, tuple_shape, "tuple"); - auto x = builder.GetTupleElement(tuple, 0); - auto y = builder.GetTupleElement(tuple, 1); - auto x_ceil = builder.Ceil(x); - auto y_ceil = builder.Ceil(y); - builder.Tuple({x_ceil, y_ceil}); + auto tuple = Parameter(&builder, 0, tuple_shape, "tuple"); + auto x = GetTupleElement(tuple, 0); + auto y = GetTupleElement(tuple, 1); + auto x_ceil = Ceil(x); + auto y_ceil = Ceil(y); + Tuple(&builder, {x_ceil, y_ceil}); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -100,12 +100,12 @@ class ConditionalOpTest : public ClientLibraryTestBase { XlaComputation CreateTupleFloorComputation(const string& computation_name, const Shape& tuple_shape) { XlaBuilder builder(computation_name); - auto tuple = builder.Parameter(0, tuple_shape, "tuple"); - auto x = builder.GetTupleElement(tuple, 0); - auto y = builder.GetTupleElement(tuple, 1); - auto x_floor = builder.Floor(x); - auto y_floor = builder.Floor(y); - builder.Tuple({x_floor, y_floor}); + auto tuple = Parameter(&builder, 0, tuple_shape, "tuple"); + auto x = GetTupleElement(tuple, 0); + auto y = GetTupleElement(tuple, 1); + auto x_floor = Floor(x); + auto y_floor = Floor(y); + Tuple(&builder, {x_floor, y_floor}); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -122,10 +122,10 @@ class ConditionalOpTest : public ClientLibraryTestBase { XlaComputation CreateTupleAddComputation(const string& computation_name, const Shape& tuple_shape) { XlaBuilder builder(computation_name); - auto tuple = builder.Parameter(0, tuple_shape, "tuple"); - auto x = builder.GetTupleElement(tuple, 0); - auto y = builder.GetTupleElement(tuple, 1); - builder.Add(x, y); + auto tuple = Parameter(&builder, 0, tuple_shape, "tuple"); + auto x = GetTupleElement(tuple, 0); + auto y = GetTupleElement(tuple, 1); + Add(x, y); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -142,10 +142,10 @@ class ConditionalOpTest : public ClientLibraryTestBase { XlaComputation CreateTupleSubComputation(const string& computation_name, const Shape& tuple_shape) { XlaBuilder builder(computation_name); - auto tuple = builder.Parameter(0, tuple_shape, "tuple"); - auto x = builder.GetTupleElement(tuple, 0); - auto y = builder.GetTupleElement(tuple, 1); - builder.Sub(x, y); + auto tuple = Parameter(&builder, 0, tuple_shape, "tuple"); + auto x = GetTupleElement(tuple, 0); + auto y = GetTupleElement(tuple, 1); + Sub(x, y); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -172,12 +172,11 @@ class ConditionalOpTest : public ClientLibraryTestBase { // Test true and false computations that do not take any parameters. XLA_TEST_F(ConditionalOpTest, Parameters0) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto operands = builder.Tuple({}); + auto pred = ConstantR0(&builder, true); + auto operands = Tuple(&builder, {}); auto true_computation = CreateR0ConstantComputation(56.0f); auto false_computation = CreateR0ConstantComputation(12.0f); - builder.Conditional(pred, operands, true_computation, operands, - false_computation); + Conditional(pred, operands, true_computation, operands, false_computation); ComputeAndCompareR0(&builder, 56.0f, {}, error_spec_); } @@ -185,11 +184,11 @@ XLA_TEST_F(ConditionalOpTest, Parameters0) { // Test true and false computations that take in 1 parameter. XLA_TEST_F(ConditionalOpTest, Parameters1) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(56.0f); - auto operand2 = builder.ConstantR0(12.0f); + auto pred = ConstantR0(&builder, false); + auto operand1 = ConstantR0(&builder, 56.0f); + auto operand2 = ConstantR0(&builder, 12.0f); auto identity = CreateR0IdentityComputation(); - builder.Conditional(pred, operand1, identity, operand2, identity); + Conditional(pred, operand1, identity, operand2, identity); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -198,11 +197,11 @@ XLA_TEST_F(ConditionalOpTest, Parameters1) { // that take in different arguments. XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(56.4f); - auto operand2 = builder.ConstantR0(12.6f); - builder.Conditional(pred, operand1, CreateR0CeilComputation(), operand2, - CreateR0FloorComputation()); + auto pred = ConstantR0(&builder, false); + auto operand1 = ConstantR0(&builder, 56.4f); + auto operand2 = ConstantR0(&builder, 12.6f); + Conditional(pred, operand1, CreateR0CeilComputation(), operand2, + CreateR0FloorComputation()); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -211,10 +210,10 @@ XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) { // that take in the same arguments. XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand = builder.ConstantR0(12.6f); - builder.Conditional(pred, operand, CreateR0CeilComputation(), operand, - CreateR0FloorComputation()); + auto pred = ConstantR0(&builder, false); + auto operand = ConstantR0(&builder, 12.6f); + Conditional(pred, operand, CreateR0CeilComputation(), operand, + CreateR0FloorComputation()); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -223,11 +222,11 @@ XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) { // take in different arguments. XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(56.4f); - auto operand2 = builder.ConstantR0(12.6f); + auto pred = ConstantR0(&builder, false); + auto operand1 = ConstantR0(&builder, 56.4f); + auto operand2 = ConstantR0(&builder, 12.6f); auto floor = CreateR0FloorComputation(); - builder.Conditional(pred, operand1, floor, operand2, floor); + Conditional(pred, operand1, floor, operand2, floor); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -236,10 +235,10 @@ XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) { // take in the same arguments. XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand = builder.ConstantR0(12.6f); + auto pred = ConstantR0(&builder, false); + auto operand = ConstantR0(&builder, 12.6f); auto floor = CreateR0FloorComputation(); - builder.Conditional(pred, operand, floor, operand, floor); + Conditional(pred, operand, floor, operand, floor); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -248,11 +247,11 @@ XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) { // and false cases. XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(56.4f); - auto operand2 = builder.ConstantR0(12.6f); - builder.Conditional(pred, operand1, CreateR0FloorComputation(), operand2, - CreateR0FloorComputation()); + auto pred = ConstantR0(&builder, false); + auto operand1 = ConstantR0(&builder, 56.4f); + auto operand2 = ConstantR0(&builder, 12.6f); + Conditional(pred, operand1, CreateR0FloorComputation(), operand2, + CreateR0FloorComputation()); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -261,19 +260,19 @@ XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) { XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); XlaBuilder inner_builder(TestName() + ".inner_conditional"); - auto pred_cond = inner_builder.Parameter(0, r0bool, "param0"); - auto true_operand = inner_builder.Parameter(1, r0f32_, "param1"); - auto false_operand = inner_builder.Parameter(2, r0f32_, "param2"); - inner_builder.Conditional(pred_cond, true_operand, CreateR0CeilComputation(), - false_operand, CreateR0FloorComputation()); + auto pred_cond = Parameter(&inner_builder, 0, r0bool, "param0"); + auto true_operand = Parameter(&inner_builder, 1, r0f32_, "param1"); + auto false_operand = Parameter(&inner_builder, 2, r0f32_, "param2"); + Conditional(pred_cond, true_operand, CreateR0CeilComputation(), false_operand, + CreateR0FloorComputation()); auto inner_builder_result = inner_builder.Build(); XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(56.4f); - auto operand2 = builder.ConstantR0(12.6f); - builder.Call(inner_builder_result.ConsumeValueOrDie(), - {pred, operand1, operand2}); + auto pred = ConstantR0(&builder, false); + auto operand1 = ConstantR0(&builder, 56.4f); + auto operand2 = ConstantR0(&builder, 12.6f); + Call(&builder, inner_builder_result.ConsumeValueOrDie(), + {pred, operand1, operand2}); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -282,12 +281,12 @@ XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) { // true. XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto operand1 = builder.ConstantR0(56.0f); - auto operand2 = builder.ConstantR0(12.0f); - auto operands = builder.Tuple({operand1, operand2}); - builder.Conditional(pred, operands, CreateR0TupleAddComputation(), operands, - CreateR0TupleSubComputation()); + auto pred = ConstantR0(&builder, true); + auto operand1 = ConstantR0(&builder, 56.0f); + auto operand2 = ConstantR0(&builder, 12.0f); + auto operands = Tuple(&builder, {operand1, operand2}); + Conditional(pred, operands, CreateR0TupleAddComputation(), operands, + CreateR0TupleSubComputation()); ComputeAndCompareR0(&builder, 68.0f, {}, error_spec_); } @@ -296,12 +295,12 @@ XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) { // false. XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(56.0f); - auto operand2 = builder.ConstantR0(12.0f); - auto operands = builder.Tuple({operand1, operand2}); - builder.Conditional(pred, operands, CreateR0TupleAddComputation(), operands, - CreateR0TupleSubComputation()); + auto pred = ConstantR0(&builder, false); + auto operand1 = ConstantR0(&builder, 56.0f); + auto operand2 = ConstantR0(&builder, 12.0f); + auto operands = Tuple(&builder, {operand1, operand2}); + Conditional(pred, operands, CreateR0TupleAddComputation(), operands, + CreateR0TupleSubComputation()); ComputeAndCompareR0(&builder, 44.0f, {}, error_spec_); } @@ -310,12 +309,12 @@ XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) { // predicate is true. XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto operand1 = builder.ConstantR1({24.0f, 56.0f}); - auto operand2 = builder.ConstantR1({10.0f, 11.0f}); - auto operands = builder.Tuple({operand1, operand2}); - builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands, - CreateR1TupleSubComputation()); + auto pred = ConstantR0(&builder, true); + auto operand1 = ConstantR1(&builder, {24.0f, 56.0f}); + auto operand2 = ConstantR1(&builder, {10.0f, 11.0f}); + auto operands = Tuple(&builder, {operand1, operand2}); + Conditional(pred, operands, CreateR1TupleAddComputation(), operands, + CreateR1TupleSubComputation()); ComputeAndCompareR1(&builder, {34.0f, 67.0f}, {}, error_spec_); } @@ -324,12 +323,12 @@ XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { // predicate is false. XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand1 = builder.ConstantR1({24.0f, 56.0f}); - auto operand2 = builder.ConstantR1({10.0f, 11.0f}); - auto operands = builder.Tuple({operand1, operand2}); - builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands, - CreateR1TupleSubComputation()); + auto pred = ConstantR0(&builder, false); + auto operand1 = ConstantR1(&builder, {24.0f, 56.0f}); + auto operand2 = ConstantR1(&builder, {10.0f, 11.0f}); + auto operands = Tuple(&builder, {operand1, operand2}); + Conditional(pred, operands, CreateR1TupleAddComputation(), operands, + CreateR1TupleSubComputation()); ComputeAndCompareR1(&builder, {14.0f, 45.0f}, {}, error_spec_); } @@ -337,11 +336,11 @@ XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) { // Test true and false computations that return a tuple of scalars. XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operands = builder.Tuple( - {builder.ConstantR0(12.2f), builder.ConstantR0(25.6f)}); - builder.Conditional(pred, operands, CreateR0TupleCeilComputation(), operands, - CreateR0TupleFloorComputation()); + auto pred = ConstantR0(&builder, false); + auto operands = Tuple(&builder, {ConstantR0(&builder, 12.2f), + ConstantR0(&builder, 25.6f)}); + Conditional(pred, operands, CreateR0TupleCeilComputation(), operands, + CreateR0TupleFloorComputation()); ComputeAndCompareTuple( &builder, @@ -353,11 +352,12 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { // Test true and false computations that return a tuple of arrays. XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto operands = builder.Tuple({builder.ConstantR1({12.2f, 15.8f}), - builder.ConstantR1({25.6f, 29.2f})}); - builder.Conditional(pred, operands, CreateR1TupleCeilComputation(), operands, - CreateR1TupleFloorComputation()); + auto pred = ConstantR0(&builder, true); + auto operands = + Tuple(&builder, {ConstantR1(&builder, {12.2f, 15.8f}), + ConstantR1(&builder, {25.6f, 29.2f})}); + Conditional(pred, operands, CreateR1TupleCeilComputation(), operands, + CreateR1TupleFloorComputation()); ComputeAndCompareTuple( &builder, @@ -371,31 +371,31 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { XlaBuilder true_builder(TestName() + ".true"); { - true_builder.Parameter(0, empty_tuple_, "tuple"); - auto true_pred = true_builder.ConstantR0(true); - auto true_scalar = true_builder.ConstantR0(12.2f); - auto true_array = true_builder.ConstantR1({12.8f, 14.6f}); - true_builder.Tuple({true_pred, true_scalar, true_array}); + Parameter(&true_builder, 0, empty_tuple_, "tuple"); + auto true_pred = ConstantR0(&true_builder, true); + auto true_scalar = ConstantR0(&true_builder, 12.2f); + auto true_array = ConstantR1(&true_builder, {12.8f, 14.6f}); + Tuple(&true_builder, {true_pred, true_scalar, true_array}); } auto true_builder_result = true_builder.Build(); EXPECT_IS_OK(true_builder_result.status()); XlaBuilder false_builder(TestName() + ".false"); { - false_builder.Parameter(0, empty_tuple_, "tuple"); - auto false_pred = false_builder.ConstantR0(false); - auto false_scalar = false_builder.ConstantR0(25.6f); - auto false_array = false_builder.ConstantR1({26.4f, 32.6f}); - false_builder.Tuple({false_pred, false_scalar, false_array}); + Parameter(&false_builder, 0, empty_tuple_, "tuple"); + auto false_pred = ConstantR0(&false_builder, false); + auto false_scalar = ConstantR0(&false_builder, 25.6f); + auto false_array = ConstantR1(&false_builder, {26.4f, 32.6f}); + Tuple(&false_builder, {false_pred, false_scalar, false_array}); } auto false_builder_result = false_builder.Build(); EXPECT_IS_OK(false_builder_result.status()); XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto operands = builder.Tuple({}); - builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), - operands, false_builder_result.ConsumeValueOrDie()); + auto pred = ConstantR0(&builder, true); + auto operands = Tuple(&builder, {}); + Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands, + false_builder_result.ConsumeValueOrDie()); ComputeAndCompareTuple( &builder, @@ -409,36 +409,37 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { XlaBuilder true_builder(TestName() + ".true"); { - true_builder.Parameter(0, empty_tuple_, "tuple"); - auto true_constant1 = true_builder.ConstantR0(12.2f); - auto true_constant2 = true_builder.ConstantR1({12.8f, 14.6f}); - auto true_constant3 = true_builder.ConstantR1({25.4f, 29.8f}); - auto true_constant4 = true_builder.ConstantR0(35.6f); - true_builder.Tuple({true_builder.Tuple({true_constant1, true_constant2}), - true_builder.Tuple({true_constant3, true_constant4})}); + Parameter(&true_builder, 0, empty_tuple_, "tuple"); + auto true_constant1 = ConstantR0(&true_builder, 12.2f); + auto true_constant2 = ConstantR1(&true_builder, {12.8f, 14.6f}); + auto true_constant3 = ConstantR1(&true_builder, {25.4f, 29.8f}); + auto true_constant4 = ConstantR0(&true_builder, 35.6f); + Tuple(&true_builder, + {Tuple(&true_builder, {true_constant1, true_constant2}), + Tuple(&true_builder, {true_constant3, true_constant4})}); } auto true_builder_result = true_builder.Build(); EXPECT_IS_OK(true_builder_result.status()); XlaBuilder false_builder(TestName() + ".false"); { - false_builder.Parameter(0, empty_tuple_, "tuple"); - auto false_constant1 = false_builder.ConstantR0(46.6f); - auto false_constant2 = false_builder.ConstantR1({54.4f, 58.4f}); - auto false_constant3 = false_builder.ConstantR1({62.1f, 67.4f}); - auto false_constant4 = false_builder.ConstantR0(9.3f); - false_builder.Tuple( - {false_builder.Tuple({false_constant1, false_constant2}), - false_builder.Tuple({false_constant3, false_constant4})}); + Parameter(&false_builder, 0, empty_tuple_, "tuple"); + auto false_constant1 = ConstantR0(&false_builder, 46.6f); + auto false_constant2 = ConstantR1(&false_builder, {54.4f, 58.4f}); + auto false_constant3 = ConstantR1(&false_builder, {62.1f, 67.4f}); + auto false_constant4 = ConstantR0(&false_builder, 9.3f); + Tuple(&false_builder, + {Tuple(&false_builder, {false_constant1, false_constant2}), + Tuple(&false_builder, {false_constant3, false_constant4})}); } auto false_builder_result = false_builder.Build(); EXPECT_IS_OK(false_builder_result.status()); XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operands = builder.Tuple({}); - builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), - operands, false_builder_result.ConsumeValueOrDie()); + auto pred = ConstantR0(&builder, false); + auto operands = Tuple(&builder, {}); + Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands, + false_builder_result.ConsumeValueOrDie()); ComputeAndCompareTuple( &builder, @@ -464,8 +465,8 @@ XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) { CreateR0Parameter(56.3f, 1, "operand1", &builder, &operand1); auto operand2_param = CreateR0Parameter(12.7f, 2, "operand2", &builder, &operand2); - builder.Conditional(pred, operand1, CreateR0CeilComputation(), operand2, - CreateR0FloorComputation()); + Conditional(pred, operand1, CreateR0CeilComputation(), operand2, + CreateR0FloorComputation()); ComputeAndCompareR0( &builder, 57.0f, @@ -484,8 +485,8 @@ XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) { &builder, &operand1); auto operand2_param = CreateR1Parameter({10.2f, 11.6f}, 2, "operand2", &builder, &operand2); - builder.Conditional(pred, operand1, CreateR1CeilComputation(), operand2, - CreateR1FloorComputation()); + Conditional(pred, operand1, CreateR1CeilComputation(), operand2, + CreateR1FloorComputation()); ComputeAndCompareR1( &builder, {10.0f, 11.0f}, @@ -499,27 +500,25 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) { { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_}); - auto param0 = inner_builder.Parameter(0, tuple_shape, "param0"); - auto pred_cond = inner_builder.GetTupleElement(param0, 0); - auto true_operand = inner_builder.GetTupleElement(param0, 1); - auto false_operand = inner_builder.GetTupleElement(param0, 2); - inner_builder.Conditional(pred_cond, true_operand, - CreateR0CeilComputation(), false_operand, - CreateR0FloorComputation()); + auto param0 = Parameter(&inner_builder, 0, tuple_shape, "param0"); + auto pred_cond = GetTupleElement(param0, 0); + auto true_operand = GetTupleElement(param0, 1); + auto false_operand = GetTupleElement(param0, 2); + Conditional(pred_cond, true_operand, CreateR0CeilComputation(), + false_operand, CreateR0FloorComputation()); } auto inner_builder_result = inner_builder.Build(); EXPECT_IS_OK(inner_builder_result.status()); XlaBuilder builder(TestName()); - auto pred1 = builder.ConstantR0(true); - auto pred2 = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(1.1f); - auto operand2 = builder.ConstantR0(12.2f); - auto operand3 = builder.ConstantR0(43.3f); - auto tuple_operand = builder.Tuple({pred2, operand1, operand2}); - builder.Conditional(pred1, tuple_operand, - inner_builder_result.ConsumeValueOrDie(), operand3, - CreateR0IdentityComputation()); + auto pred1 = ConstantR0(&builder, true); + auto pred2 = ConstantR0(&builder, false); + auto operand1 = ConstantR0(&builder, 1.1f); + auto operand2 = ConstantR0(&builder, 12.2f); + auto operand3 = ConstantR0(&builder, 43.3f); + auto tuple_operand = Tuple(&builder, {pred2, operand1, operand2}); + Conditional(pred1, tuple_operand, inner_builder_result.ConsumeValueOrDie(), + operand3, CreateR0IdentityComputation()); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -529,23 +528,22 @@ XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_}); - auto param0 = inner_builder.Parameter(0, tuple_shape, "param0"); - auto pred_cond = inner_builder.GetTupleElement(param0, 0); - auto true_operand = inner_builder.GetTupleElement(param0, 1); - auto false_operand = inner_builder.GetTupleElement(param0, 2); - inner_builder.Conditional(pred_cond, true_operand, - CreateR0CeilComputation(), false_operand, - CreateR0FloorComputation()); + auto param0 = Parameter(&inner_builder, 0, tuple_shape, "param0"); + auto pred_cond = GetTupleElement(param0, 0); + auto true_operand = GetTupleElement(param0, 1); + auto false_operand = GetTupleElement(param0, 2); + Conditional(pred_cond, true_operand, CreateR0CeilComputation(), + false_operand, CreateR0FloorComputation()); } auto inner_builder_result = inner_builder.Build(); EXPECT_IS_OK(inner_builder_result.status()); XlaBuilder builder(TestName()); - auto pred2 = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(1.1f); - auto operand2 = builder.ConstantR0(12.2f); - auto tuple_operand = builder.Tuple({pred2, operand1, operand2}); - builder.Call(inner_builder_result.ConsumeValueOrDie(), {tuple_operand}); + auto pred2 = ConstantR0(&builder, false); + auto operand1 = ConstantR0(&builder, 1.1f); + auto operand2 = ConstantR0(&builder, 12.2f); + auto tuple_operand = Tuple(&builder, {pred2, operand1, operand2}); + Call(&builder, inner_builder_result.ConsumeValueOrDie(), {tuple_operand}); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -553,12 +551,12 @@ XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { // Test a mismatch in the shape of the true operand and true computation. XLA_TEST_F(ConditionalOpTest, ShapeMismatch) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto operand1 = builder.ConstantR0(56.0f); - auto operand2 = builder.ConstantR0(12.0f); - auto operands = builder.Tuple({operand1, operand2}); - builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands, - CreateR0TupleSubComputation()); + auto pred = ConstantR0(&builder, true); + auto operand1 = ConstantR0(&builder, 56.0f); + auto operand2 = ConstantR0(&builder, 12.0f); + auto operands = Tuple(&builder, {operand1, operand2}); + Conditional(pred, operands, CreateR1TupleAddComputation(), operands, + CreateR0TupleSubComputation()); auto result = builder.Build(); EXPECT_FALSE(result.ok()); @@ -572,40 +570,40 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { XlaComputation swapper; { XlaBuilder builder(TestName() + ".swapper"); - auto param0 = builder.Parameter(0, tuple_shape, "sp0"); - auto x = builder.GetTupleElement(param0, 0); - auto y = builder.GetTupleElement(param0, 1); - builder.Tuple({y, x}); + auto param0 = Parameter(&builder, 0, tuple_shape, "sp0"); + auto x = GetTupleElement(param0, 0); + auto y = GetTupleElement(param0, 1); + Tuple(&builder, {y, x}); swapper = builder.Build().ConsumeValueOrDie(); } XlaComputation forwarder; { XlaBuilder builder(TestName() + ".forwarder"); - auto param0 = builder.Parameter(0, tuple_shape, "fp0"); - auto x = builder.GetTupleElement(param0, 0); - auto y = builder.GetTupleElement(param0, 1); - builder.Tuple({x, y}); + auto param0 = Parameter(&builder, 0, tuple_shape, "fp0"); + auto x = GetTupleElement(param0, 0); + auto y = GetTupleElement(param0, 1); + Tuple(&builder, {x, y}); forwarder = builder.Build().ConsumeValueOrDie(); } XlaComputation main; { XlaBuilder builder(TestName() + ".main"); - auto param0 = builder.Parameter(0, tuple_shape, "mp0"); - auto x = builder.GetTupleElement(param0, 0); - auto y = builder.GetTupleElement(param0, 1); - auto lt_pred = builder.Lt(x, y); - auto res = builder.Conditional(lt_pred, param0, forwarder, param0, swapper); - auto ge_pred = builder.Ge(x, y); - builder.Conditional(ge_pred, res, swapper, res, forwarder); + auto param0 = Parameter(&builder, 0, tuple_shape, "mp0"); + auto x = GetTupleElement(param0, 0); + auto y = GetTupleElement(param0, 1); + auto lt_pred = Lt(x, y); + auto res = Conditional(lt_pred, param0, forwarder, param0, swapper); + auto ge_pred = Ge(x, y); + Conditional(ge_pred, res, swapper, res, forwarder); main = builder.Build().ConsumeValueOrDie(); } auto test_swap = [&](float a, float b) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR0(a); - auto y = builder.ConstantR0(b); - auto tuple_operand = builder.Tuple({x, y}); - builder.Call(main, {tuple_operand}); + auto x = ConstantR0(&builder, a); + auto y = ConstantR0(&builder, b); + auto tuple_operand = Tuple(&builder, {x, y}); + Call(&builder, main, {tuple_operand}); ComputeAndCompareTuple( &builder, diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 916ffadbc798ec0dd016f45b0bc4c36233455ee7..cc5d3b11767457444d4c199943e689f082d5b199 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -39,7 +40,7 @@ class ConstantsTest : public ClientLibraryTestBase { TEST_F(ConstantsTest, ZeroCellF32) { XlaBuilder builder(TestName()); - builder.ConstantR1({}); + ConstantR1(&builder, {}); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -48,7 +49,7 @@ TEST_F(ConstantsTest, OneCellF32) { std::vector constant = {2.0}; XlaBuilder builder(TestName()); - builder.ConstantR1(constant); + ConstantR1(&builder, constant); ComputeAndCompareR1(&builder, constant, {}, error_spec_); } @@ -57,7 +58,7 @@ TEST_F(ConstantsTest, OneCellS32) { std::vector constant = {2}; XlaBuilder builder(TestName()); - builder.ConstantR1(constant); + ConstantR1(&builder, constant); ComputeAndCompareR1(&builder, constant, {}); } @@ -66,7 +67,7 @@ TEST_F(ConstantsTest, OneCellU32) { std::vector constant = {2}; XlaBuilder builder(TestName()); - builder.ConstantR1(constant); + ConstantR1(&builder, constant); ComputeAndCompareR1(&builder, constant, {}); } @@ -75,7 +76,7 @@ TEST_F(ConstantsTest, EightCells) { std::vector constant = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}; XlaBuilder builder(TestName()); - builder.ConstantR1(constant); + ConstantR1(&builder, constant); ComputeAndCompareR1(&builder, constant, {}, error_spec_); } @@ -85,14 +86,14 @@ TEST_F(ConstantsTest, SixteenCells) { 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0}; XlaBuilder builder(TestName()); - builder.ConstantR1(constant); + ConstantR1(&builder, constant); ComputeAndCompareR1(&builder, constant, {}, error_spec_); } TEST_F(ConstantsTest, Empty_0x2) { XlaBuilder builder(TestName()); - builder.ConstantR2FromArray2D(Array2D(0, 2)); + ConstantR2FromArray2D(&builder, Array2D(0, 2)); ComputeAndCompareR2(&builder, Array2D(0, 2), {}, error_spec_); } @@ -102,15 +103,15 @@ TEST_F(ConstantsTest, Small_2x2) { MakeLinspaceArray2D(100.0, 200.0, 2, 2); XlaBuilder builder(TestName()); - builder.ConstantR2FromArray2D(*constant); + ConstantR2FromArray2D(&builder, *constant); ComputeAndCompareR2(&builder, *constant, {}, error_spec_); } TEST_F(ConstantsTest, Empty_3x0x2) { XlaBuilder builder(TestName()); - auto constant = builder.ConstantLiteral( - *Literal::CreateR3FromArray3D(Array3D(3, 0, 2))); + ConstantLiteral( + &builder, *Literal::CreateR3FromArray3D(Array3D(3, 0, 2))); ComputeAndCompareR3(&builder, Array3D(3, 0, 2), {}); } @@ -125,8 +126,7 @@ TEST_F(ConstantsTest, Small_2x2x2) { {{5.f, 6.f}, // y0 {7.f, 8.f}}, // y1 }); - auto constant = - builder.ConstantLiteral(*Literal::CreateR3FromArray3D(array3d)); + ConstantLiteral(&builder, *Literal::CreateR3FromArray3D(array3d)); ComputeAndCompareR3(&builder, array3d, {}); } @@ -145,13 +145,13 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { { XlaBuilder builder(TestName()); - builder.ConstantLiteral(*input_literal); + ConstantLiteral(&builder, *input_literal); ComputeAndCompareR4(&builder, input_array, {}, error_spec_); } { XlaBuilder builder(TestName()); - builder.ConstantR4FromArray4D(input_array); + ConstantR4FromArray4D(&builder, input_array); ComputeAndCompareR4(&builder, input_array, {}, error_spec_); } } @@ -159,9 +159,9 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { // TODO(b/29263943): Support tuple constants. TEST_F(ConstantsTest, DISABLED_TupleConstant) { XlaBuilder builder(TestName()); - builder.ConstantLiteral( - *Literal::MakeTuple({Literal::CreateR2({{1.0}, {2.0}}).get(), - Literal::CreateR1({2.0, 42}).get()})); + ConstantLiteral(&builder, *Literal::MakeTuple( + {Literal::CreateR2({{1.0}, {2.0}}).get(), + Literal::CreateR1({2.0, 42}).get()})); std::unique_ptr result = ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie(); @@ -172,5 +172,13 @@ TEST_F(ConstantsTest, DISABLED_TupleConstant) { {2.0, 42.0}, LiteralSlice(*result, {1}), error_spec_); } +TEST_F(ConstantsTest, Token) { + XlaBuilder builder(TestName()); + ConstantLiteral(&builder, *Literal::CreateToken()); + // TODO(b/80000000): tokens cannot be returned from computations. + Tuple(&builder, {}); + TF_ASSERT_OK(Execute(&builder, {}).status()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 722d882471a41a75c1e5e60f8c1a151b76c7e004..292942a49e2f0c4b077dc71c9d0e730909689e3a 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -45,8 +45,8 @@ class ConvertTest : public ClientLibraryTestBase { TEST_F(ConvertTest, ConvertR1S32ToR1S32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42, 64}); - builder.ConvertElementType(a, S32); + auto a = ConstantR1(&builder, {42, 64}); + ConvertElementType(a, S32); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}); @@ -54,8 +54,8 @@ TEST_F(ConvertTest, ConvertR1S32ToR1S32) { TEST_F(ConvertTest, ConvertR1F32ToR1F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0f, 64.0f}); - builder.ConvertElementType(a, F32); + auto a = ConstantR1(&builder, {42.0f, 64.0f}); + ConvertElementType(a, F32); std::vector expected = {42.0f, 64.0f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -63,8 +63,8 @@ TEST_F(ConvertTest, ConvertR1F32ToR1F32) { TEST_F(ConvertTest, ConvertR1S32ToR1F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42, 64}); - builder.ConvertElementType(a, F32); + auto a = ConstantR1(&builder, {42, 64}); + ConvertElementType(a, F32); std::vector expected = {42.0f, 64.0f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -72,8 +72,8 @@ TEST_F(ConvertTest, ConvertR1S32ToR1F32) { TEST_F(ConvertTest, ConvertR1PREDToR1S32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({true, false, true}); - builder.ConvertElementType(a, S32); + auto a = ConstantR1(&builder, {true, false, true}); + ConvertElementType(a, S32); std::vector expected = {1, 0, 1}; ComputeAndCompareR1(&builder, expected, {}); @@ -81,8 +81,8 @@ TEST_F(ConvertTest, ConvertR1PREDToR1S32) { TEST_F(ConvertTest, ConvertR1PREDToR1F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({true, false, true}); - builder.ConvertElementType(a, F32); + auto a = ConstantR1(&builder, {true, false, true}); + ConvertElementType(a, F32); std::vector expected = {1., 0., 1.}; ComputeAndCompareR1(&builder, expected, {}); @@ -90,8 +90,8 @@ TEST_F(ConvertTest, ConvertR1PREDToR1F32) { XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.ConvertElementType(a, F32); + auto a = ConstantR1(&builder, {}); + ConvertElementType(a, F32); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -99,8 +99,8 @@ XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) { TEST_F(ConvertTest, ConvertR1F32ToR1S32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.6, 64.4}); - builder.ConvertElementType(a, S32); + auto a = ConstantR1(&builder, {42.6, 64.4}); + ConvertElementType(a, S32); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}); @@ -146,11 +146,11 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) { static_cast(0x8000010000000000LL), }; std::unique_ptr arg_literal = Literal::CreateR1({arg}); - auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); - builder.ConvertElementType(arg_param, F32); + ConvertElementType(arg_param, F32); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { @@ -165,11 +165,11 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) { 0x80000000, 0x80000001, 0x80000002, 0x80000003, 0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF}; std::unique_ptr arg_literal = Literal::CreateR1({arg}); - auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); - builder.ConvertElementType(arg_param, F32); + ConvertElementType(arg_param, F32); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { @@ -183,11 +183,11 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { std::vector arg{0.0f, 1.0f, 16777216.0f, 16777218.0f, 2147483647.0f, 4294967040.0f}; std::unique_ptr arg_literal = Literal::CreateR1({arg}); - auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); - builder.ConvertElementType(arg_param, U32); + ConvertElementType(arg_param, U32); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { @@ -200,11 +200,11 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF}; std::unique_ptr arg_literal = Literal::CreateR1({arg}); - auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); - builder.ConvertElementType(arg_param, S64); + ConvertElementType(arg_param, S64); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { @@ -217,11 +217,11 @@ XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, -1, -0x1000}; std::unique_ptr arg_literal = Literal::CreateR1({arg}); - auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); - builder.ConvertElementType(arg_param, S64); + ConvertElementType(arg_param, S64); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { @@ -254,11 +254,11 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { -9223371487098961920.f, -9223370937343148032.f}; std::unique_ptr arg_literal = Literal::CreateR1({arg}); - auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); - builder.ConvertElementType(arg_param, S64); + ConvertElementType(arg_param, S64); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { @@ -269,8 +269,8 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({32, 64}); - builder.ConvertElementType(a, F32); + auto a = ConstantR1(&builder, {32, 64}); + ConvertElementType(a, F32); std::vector expected = {32.0, 64.0}; ComputeAndCompareR1(&builder, expected, {}); @@ -278,8 +278,8 @@ XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) { XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({32, 64}); - builder.ConvertElementType(a, S32); + auto a = ConstantR1(&builder, {32, 64}); + ConvertElementType(a, S32); std::vector expected = {32, 64}; ComputeAndCompareR1(&builder, expected, {}); @@ -287,8 +287,8 @@ XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) { XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({32, 64}); - builder.ConvertElementType(a, U32); + auto a = ConstantR1(&builder, {32, 64}); + ConvertElementType(a, U32); std::vector expected = {32, 64}; ComputeAndCompareR1(&builder, expected, {}); @@ -296,8 +296,8 @@ XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) { XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({32.0f, 64.0f}); - builder.ConvertElementType(a, F64); + auto a = ConstantR1(&builder, {32.0f, 64.0f}); + ConvertElementType(a, F64); std::vector expected = {32.0, 64.0}; ComputeAndCompareR1(&builder, expected, {}); @@ -305,8 +305,8 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) { XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({32.0, 64.0}); - builder.ConvertElementType(a, F32); + auto a = ConstantR1(&builder, {32.0, 64.0}); + ConvertElementType(a, F32); std::vector expected = {32.0f, 64.0f}; ComputeAndCompareR1(&builder, expected, {}); @@ -314,9 +314,9 @@ XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) { TEST_F(ConvertTest, ConvertS32Extremes) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {std::numeric_limits::min(), std::numeric_limits::max()}); - builder.ConvertElementType(a, F32); + auto a = ConstantR1(&builder, {std::numeric_limits::min(), + std::numeric_limits::max()}); + ConvertElementType(a, F32); std::vector expected = { static_cast(std::numeric_limits::min()), @@ -327,10 +327,10 @@ TEST_F(ConvertTest, ConvertS32Extremes) { TEST_F(ConvertTest, ConvertMapToS32) { XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); - auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in"); - b->ConvertElementType(param, S32); - auto a = builder.ConstantR1({42.0f, 64.0f}); - builder.Map({a}, b->BuildAndNoteError(), {0}); + auto param = Parameter(b.get(), 0, ShapeUtil::MakeShape(F32, {}), "in"); + ConvertElementType(param, S32); + auto a = ConstantR1(&builder, {42.0f, 64.0f}); + Map(&builder, {a}, b->BuildAndNoteError(), {0}); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}); @@ -339,10 +339,10 @@ TEST_F(ConvertTest, ConvertMapToS32) { TEST_F(ConvertTest, ConvertMapToF32) { XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); - auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in"); - b->ConvertElementType(param, F32); - auto a = builder.ConstantR1({42, 64}); - builder.Map({a}, b->BuildAndNoteError(), {0}); + auto param = Parameter(b.get(), 0, ShapeUtil::MakeShape(S32, {}), "in"); + ConvertElementType(param, F32); + auto a = ConstantR1(&builder, {42, 64}); + Map(&builder, {a}, b->BuildAndNoteError(), {0}); std::vector expected = {42.0f, 64.0f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -355,9 +355,9 @@ TEST_F(ConvertTest, ConvertMapToF32) { // the new convert should have the same element type as the old convert. TEST_F(ConvertTest, ConvertReshape) { XlaBuilder builder(TestName()); - auto input = builder.ConstantR1({42}); - auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); - builder.ConvertElementType(reshape, F32); + auto input = ConstantR1(&builder, {42}); + auto reshape = Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); + ConvertElementType(reshape, F32); ComputeAndCompareR0(&builder, 42.0f, {}, ErrorSpec(0.0001)); } @@ -394,10 +394,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { client_->TransferToServer(*Literal::CreateR1(input))); XlaBuilder builder(TestName()); - builder.ConvertElementType( - builder.Parameter( - 0, ShapeUtil::MakeShape(F16, {static_cast(input.size())}), - "param"), + ConvertElementType( + Parameter(&builder, 0, + ShapeUtil::MakeShape(F16, {static_cast(input.size())}), + "param"), F32); ComputeAndCompareR1(&builder, expected_output, {dot_lhs_handle.get()}); @@ -414,10 +414,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { client_->TransferToServer(*Literal::CreateR1(input))); XlaBuilder builder(TestName()); - builder.ConvertElementType( - builder.Parameter( - 0, ShapeUtil::MakeShape(F32, {static_cast(input.size())}), - "param"), + ConvertElementType( + Parameter(&builder, 0, + ShapeUtil::MakeShape(F32, {static_cast(input.size())}), + "param"), F16); ComputeAndCompareR1(&builder, expected_output, {dot_lhs_handle.get()}); @@ -426,28 +426,28 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { XLA_TEST_F(ConvertTest, ConvertC64ToC64) { XlaBuilder builder(TestName()); std::vector x = {{42.0f, 64.0f}}; - builder.ConvertElementType(builder.ConstantR1(x), C64); + ConvertElementType(ConstantR1(&builder, x), C64); ComputeAndCompareR1(&builder, x, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConvertTest, ConvertS64S64) { XlaBuilder builder(TestName()); std::vector x = {{-42, 64}}; - builder.ConvertElementType(builder.ConstantR1(x), S64); + ConvertElementType(ConstantR1(&builder, x), S64); ComputeAndCompareR1(&builder, x, {}); } XLA_TEST_F(ConvertTest, ConvertU64U64) { XlaBuilder builder(TestName()); std::vector x = {{42, 64}}; - builder.ConvertElementType(builder.ConstantR1(x), U64); + ConvertElementType(ConstantR1(&builder, x), U64); ComputeAndCompareR1(&builder, x, {}); } XLA_TEST_F(ConvertTest, ConvertU64S64) { XlaBuilder builder(TestName()); std::vector unsigned_x = {{42, UINT64_MAX}}; - builder.ConvertElementType(builder.ConstantR1(unsigned_x), S64); + ConvertElementType(ConstantR1(&builder, unsigned_x), S64); std::vector signed_x = {{42, -1}}; ComputeAndCompareR1(&builder, signed_x, {}); } @@ -455,11 +455,31 @@ XLA_TEST_F(ConvertTest, ConvertU64S64) { XLA_TEST_F(ConvertTest, ConvertS64U64) { XlaBuilder builder(TestName()); std::vector signed_x = {{42, -1, INT64_MIN}}; - builder.ConvertElementType(builder.ConstantR1(signed_x), U64); + ConvertElementType(ConstantR1(&builder, signed_x), U64); std::vector unsigned_x = { {42, UINT64_MAX, tensorflow::MathUtil::IPow(2, 63)}}; ComputeAndCompareR1(&builder, unsigned_x, {}); } +XLA_TEST_F(ConvertTest, ConvertBF16F32) { + XlaBuilder builder(TestName()); + + std::vector all_bfloats(1 << 16); + for (int i = 0; i < all_bfloats.size(); ++i) { + all_bfloats[i].value = i; + } + + std::vector expected(all_bfloats.size()); + for (int i = 0; i < expected.size(); ++i) { + expected[i] = (1U << 16) * i; + } + + // Exhaustively test all bf16 to f32 conversions. + xla::XlaOp all_bfloats_bf16 = ConstantR1(&builder, all_bfloats); + xla::XlaOp all_bfloats_f32 = ConvertElementType(all_bfloats_bf16, F32); + BitcastConvertType(all_bfloats_f32, U32); + ComputeAndCompareR1(&builder, expected, {}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index b5a42e305987df030c15d089f5877f73bb61de1b..7605ebf4c0eacd7f44e867e23dbc27c6c1bc3e93 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -97,10 +97,10 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, .ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto input = builder.ConstantR4FromArray4D(*input_array); + auto input = ConstantR4FromArray4D(&builder, *input_array); auto weight = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {4, 3, 1, 1}), "weight"); - auto conv1 = builder.Conv(input, weight, {1, 1}, Padding::kValid); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {4, 3, 1, 1}), "weight"); + auto conv1 = Conv(input, weight, {1, 1}, Padding::kValid); ConvolutionDimensionNumbers dim_nums = XlaBuilder::CreateDefaultConvDimensionNumbers(); @@ -117,8 +117,7 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, dim_nums.set_kernel_input_feature_dimension( dim_nums.kernel_output_feature_dimension()); dim_nums.set_kernel_output_feature_dimension(old_kernel_input_feature_dim); - builder.ConvWithGeneralDimensions(input, conv1, {1, 1}, Padding::kValid, - dim_nums); + ConvWithGeneralDimensions(input, conv1, {1, 1}, Padding::kValid, dim_nums); auto expected_conv1 = ReferenceUtil::ConvArray4D(*input_array, *weight_array, {1, 1}, Padding::kValid); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 346bb3a3996ee5bf662b0f74dd0c2096efbf5295..0f6d54d042dd6af6d82e1eea93a66c2e9be53639 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -89,9 +89,9 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { ASSERT_EQ(2, arhs->height()); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR4FromArray4D(*alhs); - auto rhs = builder.ConstantR4FromArray4D(*arhs); - builder.Conv(lhs, rhs, {1, 1}, Padding::kValid); + auto lhs = ConstantR4FromArray4D(&builder, *alhs); + auto rhs = ConstantR4FromArray4D(&builder, *arhs); + Conv(lhs, rhs, {1, 1}, Padding::kValid); ComputeAndCompare(&builder, {}, error_spec_); } @@ -109,9 +109,9 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest { XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 1, 2}); Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 1, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D input_data(1, 1, 1, 2); input_data.FillWithYX(Array2D({ @@ -140,9 +140,9 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest { XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 4, 4}); Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D input_data(1, 1, 4, 4); input_data.FillWithYX(Array2D({ @@ -174,9 +174,9 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest { XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 4, 4}); Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kSame); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + Conv(input, filter, {1, 1}, Padding::kSame); Array4D input_data(1, 1, 4, 4); input_data.FillWithYX(Array2D({ @@ -210,9 +210,9 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest { XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 4, 4}); Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 3, 3}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kSame); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + Conv(input, filter, {1, 1}, Padding::kSame); Array4D input_data(1, 1, 4, 4); input_data.FillWithYX(Array2D({{1.0f, 2.0f, 3.0f, 4.0f}, @@ -238,9 +238,9 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1}, Padding::kValid); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + Conv(input, filter, {1}, Padding::kValid); } Array3D input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}}); @@ -268,10 +268,10 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest { { Shape input_shape = ShapeUtil::MakeShapeWithType({1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); // Convolution dimensions are bf0_oi0->bo0. - builder.ConvGeneralDilated( + ConvGeneralDilated( input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}}, /*lhs_dilation=*/{1}, /*rhs_dilation=*/{2}, /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); @@ -304,10 +304,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) { { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); // Convolution dimensions are bf0_oi0->bo0. - builder.ConvGeneralDilated( + ConvGeneralDilated( input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}}, /*lhs_dilation=*/{2}, /*rhs_dilation=*/{1}, /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); @@ -335,10 +335,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) { { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); // Convolution dimensions are bf0_oi0->bo0. - builder.ConvGeneralDilated( + ConvGeneralDilated( input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}}, /*lhs_dilation=*/{2}, /*rhs_dilation=*/{2}, /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); @@ -369,10 +369,10 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest { { Shape input_shape = ShapeUtil::MakeShapeWithType({1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); // Convolution dimensions are bf0_oi0->bo0. - builder.ConvGeneralDilated( + ConvGeneralDilated( input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}}, /*lhs_dilation=*/{1}, /*rhs_dilation=*/{1}, /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); @@ -408,8 +408,8 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { Shape input_shape = ShapeUtil::MakeShape(F32, input_dims); Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims); { - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); // Tensorflow dimension numbers for 3D convolution. ConvolutionDimensionNumbers dnums; @@ -429,8 +429,7 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { dnums.set_kernel_input_feature_dimension(3); dnums.set_kernel_output_feature_dimension(4); - builder.ConvWithGeneralDimensions(input, filter, {1, 1, 1}, Padding::kValid, - dnums); + ConvWithGeneralDimensions(input, filter, {1, 1, 1}, Padding::kValid, dnums); } std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); @@ -475,8 +474,8 @@ class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest { Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); { - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); // Tensorflow dimension numbers for 2D convolution. ConvolutionDimensionNumbers dnums; @@ -493,8 +492,7 @@ class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest { dnums.set_kernel_input_feature_dimension(2); dnums.set_kernel_output_feature_dimension(3); - builder.ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, - dnums); + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums); } std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); @@ -541,8 +539,8 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, Shape input_shape = ShapeUtil::MakeShape(F32, {4, 29}); Shape filter_shape = ShapeUtil::MakeShape(F32, {4, 10}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); ConvolutionDimensionNumbers dnums; dnums.set_input_feature_dimension(0); @@ -551,7 +549,7 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, dnums.set_kernel_output_feature_dimension(1); dnums.set_output_batch_dimension(0); dnums.set_output_feature_dimension(1); - builder.ConvWithGeneralDimensions(input, filter, {}, Padding::kValid, dnums); + ConvWithGeneralDimensions(input, filter, {}, Padding::kValid, dnums); Array2D param0(4, 29); param0.FillUnique(); @@ -599,8 +597,8 @@ class Convolve1D1WindowTestBase Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); { - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); // Tensorflow dimension numbers for 1D convolution. ConvolutionDimensionNumbers dnums; @@ -614,8 +612,7 @@ class Convolve1D1WindowTestBase dnums.set_kernel_input_feature_dimension(1); dnums.set_kernel_output_feature_dimension(2); - builder.ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid, - dnums); + ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid, dnums); } std::vector input_elems(ShapeUtil::ElementsIn(input_shape), @@ -726,9 +723,9 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2}); Shape filter_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D input_data(1, 1, 1, 2); input_data.FillWithYX(Array2D({ @@ -754,9 +751,9 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) { XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D input_data(1, 1, 1, 2); input_data.FillIota(0); diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index fea850dc135e33fe098aa755c6fdd93319cd2837..c31d033bb0f0e52d40251c4d7b64d52f42d29dc6 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -55,12 +55,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Minimal) { XlaBuilder builder(TestName()); const Array4D input_array(1, 1, 1, 1, {2}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 1, {3}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); const Array4D expected(1, 1, 1, 1, {6}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -70,12 +70,12 @@ XLA_TEST_F(ConvolutionVariantsTest, MinimalWithBatch) { XlaBuilder builder(TestName()); const Array4D input_array(5, 1, 1, 1, {1, 2, 3, 4, 5}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 1, {2}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); const Array4D expected(5, 1, 1, 1, {2, 4, 6, 8, 10}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -86,12 +86,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Flat1x1) { Array4D input_array(2, 1, 3, 4); input_array.FillWithMultiples(1); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 1, {2.3}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(2, 1, 3, 4); expected.FillWithMultiples(2.3); @@ -102,12 +102,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Deep1x1) { XlaBuilder builder(TestName()); Array4D input_array(1, 2, 1, 1, {10, 1}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 3, 1, 1, {12, 34, 56}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -117,12 +117,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x2) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 2, {1, 2}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 1, 1, {12}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -132,12 +132,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x3) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 3, {1, 2, 3}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 1, 2, {12, 23}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -147,12 +147,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x2) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 2, 1, {12, 34}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -162,12 +162,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x1in2x2) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 2, 1, {10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 1, 2, {13, 24}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -177,12 +177,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2in2x2) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 2, 2, {1000, 100, 10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 1, 1, {1234}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -194,13 +194,13 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x3WithDepthAndBatch) { Array4D input_array( 2, 2, 2, 3, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, // plane 0 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 0, 0}); // plane 1 - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array( 2, 2, 1, 2, {1000, 100, 10, 1, 0.1, 0.01, 0.001, 0.0001}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected( 2, 2, 2, 2, @@ -213,12 +213,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x4) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 4, {1, 2, 3, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 1, {10}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 2}, Padding::kValid); + Conv(input, filter, {1, 2}, Padding::kValid); Array4D expected(1, 1, 1, 2, {10, 30}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -228,12 +228,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x5) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 5, {1, 2, 3, 4, 5}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 1, {10}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 2}, Padding::kValid); + Conv(input, filter, {1, 2}, Padding::kValid); Array4D expected(1, 1, 1, 3, {10, 30, 50}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -243,12 +243,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x4) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 4, {1, 2, 3, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 3, {100, 10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 2}, Padding::kValid); + Conv(input, filter, {1, 2}, Padding::kValid); Array4D expected(1, 1, 1, 1, {123}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -258,12 +258,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x5) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 5, {1, 2, 3, 4, 5}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 3, {100, 10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 2}, Padding::kValid); + Conv(input, filter, {1, 2}, Padding::kValid); Array4D expected(1, 1, 1, 2, {123, 345}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -273,12 +273,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride2x2in3x3) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 1, {10}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {2, 2}, Padding::kValid); + Conv(input, filter, {2, 2}, Padding::kValid); Array4D expected(1, 1, 2, 2, {10, 30, 70, 90}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -288,12 +288,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter3x1in1x1Padded) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 1, {1}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 3, {10, 20, 30}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kSame); + Conv(input, filter, {1, 1}, Padding::kSame); Array4D expected(1, 1, 1, 1, {20}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -303,12 +303,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter5x1in3x1Padded) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 3, {1, 2, 3}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 5, {10000, 1000, 100, 10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kSame); + Conv(input, filter, {1, 1}, Padding::kSame); Array4D expected(1, 1, 1, 3, {123, 1230, 12300}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -318,15 +318,15 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 3, 3, {10000, 0, 1000, // row 0 0, 100, 0, // row 1 10, 0, 1}); // row 2 - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kSame); + Conv(input, filter, {1, 1}, Padding::kSame); Array4D expected(1, 1, 2, 2, {104, 230, 2300, 10400}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -336,12 +336,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1in2x1WithPaddingAndDepth) { XlaBuilder builder(TestName()); Array4D input_array(1, 2, 1, 2, {1, 2, 3, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 2, 1, 1, {10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kSame); + Conv(input, filter, {1, 1}, Padding::kSame); Array4D expected(1, 1, 1, 2, {13, 24}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -351,12 +351,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2Stride1x1Input3x3) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 2, 2, {7, 13, 17, 23}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 2, 2, {216, 276, 396, 456}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -366,12 +366,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2Stride1x1Input1x3) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 3, {1, 2, 3}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 2, {7, 13}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 1, 2, {33, 53}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -383,15 +383,15 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x1x8x8Input1x1x8x8) { std::vector input_data(64); std::iota(input_data.begin(), input_data.end(), 0.0); Array4D input_array(1, 1, 8, 8, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(128); std::fill(filter_data.begin(), filter_data.begin() + 64, 1.0); std::fill(filter_data.begin() + 64, filter_data.begin() + 128, 2.0); const Array4D filter_array(2, 1, 8, 8, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 2, 1, 1, {2016, 4032}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -403,14 +403,14 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input16x1x1x1) { std::vector input_data(16 * 1 * 1 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); Array4D input_array(16, 1, 1, 1, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 1 * 1 * 1); std::iota(filter_data.begin(), filter_data.end(), 1.0); const Array4D filter_array(1, 1, 1, 1, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); std::vector expected_data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; @@ -432,14 +432,14 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input16x1x2x2) { } } } - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 1 * ky * kx); std::iota(filter_data.begin(), filter_data.end(), 1.0); const Array4D filter_array(1, 1, ky, kx, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); std::vector expected_data(bs); for (int i = 0; i < bs; ++i) { @@ -463,14 +463,14 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) { } } } - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 1 * ky * kx); std::iota(filter_data.begin(), filter_data.end(), 1.0); const Array4D filter_array(1, 1, ky, kx, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); std::vector expected_data = { 23, @@ -492,14 +492,14 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x8x8Input16x1x8x8) { } } } - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 1 * 8 * 8); std::iota(filter_data.begin(), filter_data.end(), 1.0); const Array4D filter_array(1, 1, 8, 8, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); std::vector expected_data = { 19664, 21744, 23824, 25904, 27984, 30064, 32144, 34224, @@ -515,7 +515,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) { std::vector input_data(2 * 8 * 8); std::iota(input_data.begin(), input_data.end(), 0.0); Array4D input_array(1, 2, 8, 8, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(2 * 2 * 8 * 8); std::fill(filter_data.begin(), filter_data.begin() + filter_data.size() / 4, @@ -527,9 +527,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) { std::fill(filter_data.begin() + 3 * filter_data.size() / 4, filter_data.end(), 4.0); const Array4D filter_array(2, 2, 8, 8, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 2, 1, 1, {14240, 30496}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -541,7 +541,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) { std::vector input_data(2 * 2 * 8 * 8); std::iota(input_data.begin(), input_data.end(), 0.0); Array4D input_array(2, 2, 8, 8, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(2 * 2 * 8 * 8); std::fill(filter_data.begin(), filter_data.begin() + filter_data.size() / 4, @@ -553,9 +553,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) { std::fill(filter_data.begin() + 3 * filter_data.size() / 4, filter_data.end(), 4.0); const Array4D filter_array(2, 2, 8, 8, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(2, 2, 1, 1, {14240, 30496, 38816, 87840}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -567,7 +567,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) { std::vector input_data(32 * 2 * 8 * 8); std::iota(input_data.begin(), input_data.end(), 0.0); Array4D input_array(32, 2, 8, 8, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(2 * 2 * 8 * 8); std::fill(filter_data.begin(), filter_data.begin() + filter_data.size() / 4, @@ -579,9 +579,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) { std::fill(filter_data.begin() + 3 * filter_data.size() / 4, filter_data.end(), 4.0); const Array4D filter_array(2, 2, 8, 8, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); std::vector expected_data = { 14240, 30496, 38816, 87840, 63392, 145184, 87968, @@ -613,9 +613,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter16x16x1x1Input16x16x1x1) { } } - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(16, 16, 1, 1); for (int i0 = 0; i0 < 16; ++i0) { @@ -635,9 +635,9 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatRhsDilation) { Array4D input_array(1, 1, 4, 6, input_data); Array4D filter_array(1, 1, 2, 3, {1, 10, 100, 2, 20, 200}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneralDilated( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{}, /*lhs_dilation=*/{}, /*rhs_dilation=*/{2, 2}, XlaBuilder::CreateDefaultConvDimensionNumbers()); @@ -654,9 +654,9 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation1D) { Array4D input_array(1, 1, 1, 5, input_data); Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneralDilated( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{}, /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{}, XlaBuilder::CreateDefaultConvDimensionNumbers()); @@ -677,9 +677,9 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) { 200, 20, 2, // 300, 30, 3, // 400, 40, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneralDilated( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{2, 1}, /*padding=*/{{1, 0}, {0, 0}}, /*lhs_dilation=*/{3, 2}, /*rhs_dilation=*/{}, XlaBuilder::CreateDefaultConvDimensionNumbers()); @@ -699,9 +699,9 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingOnBothEnds) { Array4D input_array(1, 1, 1, 5, input_data); Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneral( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneral( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {-1, -1}}, XlaBuilder::CreateDefaultConvDimensionNumbers()); @@ -718,9 +718,9 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingLowAndPositivePaddingHigh) { Array4D input_array(1, 1, 1, 5, input_data); Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneral( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneral( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {-1, 2}}, XlaBuilder::CreateDefaultConvDimensionNumbers()); @@ -737,9 +737,9 @@ XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingLowAndNegativePaddingHigh) { Array4D input_array(1, 1, 1, 5, input_data); Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneral( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneral( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {2, -1}}, XlaBuilder::CreateDefaultConvDimensionNumbers()); @@ -756,9 +756,9 @@ XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingAndDilation) { Array4D input_array(1, 1, 1, 5, input_data); Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneralDilated( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {3, 2}}, /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{1, 2}, @@ -781,9 +781,9 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingAndDilation) { Array4D input_array(1, 1, 1, 5, input_data); Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneralDilated( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {-3, -2}}, /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{1, 2}, @@ -821,9 +821,9 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input1x1x2x3_Filter2x1x1x2) { Array4D filter_array(oz, iz, ky, kx, kernel_data); XlaBuilder builder(TestName()); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + Conv(input, filter, {1, 1}, Padding::kValid); std::unique_ptr> expected = ReferenceUtil::ConvArray4D( input_array, filter_array, {1, 1}, Padding::kValid); @@ -854,9 +854,9 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input1x16x1x1_Filter1x16x1x1) { Array4D filter_array(oz, iz, ky, kx, kernel_data); XlaBuilder builder(TestName()); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + Conv(input, filter, {1, 1}, Padding::kValid); std::unique_ptr> expected = ReferenceUtil::ConvArray4D( input_array, filter_array, {1, 1}, Padding::kValid); @@ -887,9 +887,9 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter1x16x1x1) { Array4D filter_array(oz, iz, ky, kx, kernel_data); XlaBuilder builder(TestName()); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + Conv(input, filter, {1, 1}, Padding::kValid); std::unique_ptr> expected = ReferenceUtil::ConvArray4D( input_array, filter_array, {1, 1}, Padding::kValid); @@ -920,9 +920,9 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter16x16x1x1) { Array4D filter_array(oz, iz, ky, kx, kernel_data); XlaBuilder builder(TestName()); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + Conv(input, filter, {1, 1}, Padding::kValid); std::unique_ptr> expected = ReferenceUtil::ConvArray4D( input_array, filter_array, {1, 1}, Padding::kValid); @@ -954,9 +954,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Array4D filter_array(oz, iz, ky, kx, kernel_data); XlaBuilder builder(TestName()); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + Conv(input, filter, {1, 1}, Padding::kValid); std::unique_ptr> expected = ReferenceUtil::ConvArray4D( input_array, filter_array, {1, 1}, Padding::kValid); @@ -970,12 +970,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) { std::vector input_data(1 * 2 * 3 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); Array4D input_array(1, 2, 3, 1, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 2 * 1 * 1); std::iota(filter_data.begin(), filter_data.end(), 1.0); Array4D filter_array(1, 2, 1, 1, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); ConvolutionDimensionNumbers dnums; // NHWC input format. @@ -995,7 +995,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) { dnums.set_kernel_output_feature_dimension(3); // Tests padding sizes that don't correspond either to SAME or VALID padding. - builder.ConvGeneral(input, filter, {1, 1}, {{2, 1}, {2, 3}}, dnums); + ConvGeneral(input, filter, {1, 1}, {{2, 1}, {2, 3}}, dnums); std::vector expected_data = { 0, 0, 0, 0, 0, 0, 0, // @@ -1014,12 +1014,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) { std::vector input_data(1 * 2 * 3 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); Array4D input_array(1, 2, 3, 1, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 1 * 1 * 1); std::iota(filter_data.begin(), filter_data.end(), 2.0); Array4D filter_array(1, 1, 1, 1, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); ConvolutionDimensionNumbers dnums; // NHWC input format. @@ -1039,7 +1039,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) { dnums.set_kernel_output_feature_dimension(3); // Tests padding sizes that don't correspond either to SAME or VALID padding. - builder.ConvGeneral(input, filter, {1, 1}, {{2, 1}, {2, 3}}, dnums); + ConvGeneral(input, filter, {1, 1}, {{2, 1}, {2, 3}}, dnums); std::vector expected_data = { 0, 0, 0, 0, 0, 0, 0, 0, // @@ -1058,12 +1058,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) { std::vector input_data(1 * 2 * 3 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); Array4D input_array(1, 2, 3, 1, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 1 * 1 * 1); std::iota(filter_data.begin(), filter_data.end(), 2.0); Array4D filter_array(1, 1, 1, 1, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); ConvolutionDimensionNumbers dnums; // NHWC input format. @@ -1083,7 +1083,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) { dnums.set_kernel_output_feature_dimension(3); // Tests zero padding sizes. This can use matmul for computation. - builder.ConvGeneral(input, filter, {1, 1}, {{0, 0}, {0, 0}}, dnums); + ConvGeneral(input, filter, {1, 1}, {{0, 0}, {0, 0}}, dnums); std::vector expected_data = { 2, 4, 6, // @@ -1099,12 +1099,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { std::vector input_data(1 * 2 * 3 * 2); std::iota(input_data.begin(), input_data.end(), 1.0); Array4D input_array(1, 2, 3, 2, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 1 * 2 * 3); std::iota(filter_data.begin(), filter_data.end(), 2.0); Array4D filter_array(1, 1, 2, 3, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); ConvolutionDimensionNumbers dnums; // NHWC input format. @@ -1124,7 +1124,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { dnums.set_kernel_output_feature_dimension(3); // Tests zero padding sizes. This can use matmul for computation. - builder.ConvGeneral(input, filter, {1, 1}, {{0, 0}, {0, 0}}, dnums); + ConvGeneral(input, filter, {1, 1}, {{0, 0}, {0, 0}}, dnums); std::vector expected_data = { 12, 15, 18, // @@ -1148,14 +1148,14 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingLessThanHighPadding) { XlaBuilder builder(TestName()); - auto gradients = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 3, /*values=*/{1, 2, 3})); - auto weights = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 2, /*values=*/{5, 6})); - auto mirrored_weights = builder.Rev(weights, {2, 3}); - builder.ConvWithGeneralPadding(gradients, mirrored_weights, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {1, 0}}); + auto gradients = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 3, /*values=*/{1, 2, 3})); + auto weights = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 2, /*values=*/{5, 6})); + auto mirrored_weights = Rev(weights, {2, 3}); + ConvWithGeneralPadding(gradients, mirrored_weights, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {1, 0}}); ComputeAndCompareR4(&builder, {{{{5, 16, 27}}}}, {}, error_spec_); } @@ -1167,16 +1167,16 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingGreaterThanHighPadding) { XlaBuilder builder(TestName()); - auto gradients = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 1, /*values=*/{1})); - auto weights = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 3, /*values=*/{1, 10, 100})); - auto mirrored_weights = builder.Rev(weights, {2, 3}); - builder.ConvGeneralDilated(gradients, mirrored_weights, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {0, 3}}, - /*lhs_dilation=*/{1, 3}, /*rhs_dilation=*/{}, - XlaBuilder::CreateDefaultConvDimensionNumbers()); + auto gradients = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 1, /*values=*/{1})); + auto weights = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 3, /*values=*/{1, 10, 100})); + auto mirrored_weights = Rev(weights, {2, 3}); + ConvGeneralDilated(gradients, mirrored_weights, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {0, 3}}, + /*lhs_dilation=*/{1, 3}, /*rhs_dilation=*/{}, + XlaBuilder::CreateDefaultConvDimensionNumbers()); ComputeAndCompareR4(&builder, {{{{100, 0}}}}, {}, error_spec_); } @@ -1187,14 +1187,14 @@ XLA_TEST_F(ConvolutionVariantsTest, XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) { XlaBuilder builder(TestName()); - auto gradients = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 1, /*values=*/{1})); - auto weights = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 3, /*values=*/{1, 10, 100})); - auto mirrored_weights = builder.Rev(weights, {2, 3}); - builder.ConvWithGeneralPadding(gradients, mirrored_weights, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {1, 1}}); + auto gradients = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 1, /*values=*/{1})); + auto weights = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 3, /*values=*/{1, 10, 100})); + auto mirrored_weights = Rev(weights, {2, 3}); + ConvWithGeneralPadding(gradients, mirrored_weights, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {1, 1}}); ComputeAndCompareR4(&builder, {{{{10}}}}, {}, error_spec_); } @@ -1208,14 +1208,14 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) { XLA_TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) { XlaBuilder builder(TestName()); - auto gradients = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 3, /*values=*/{1, 2, 3})); - auto weights = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 2, /*values=*/{1, 10})); - auto mirrored_weights = builder.Rev(weights, {2, 3}); - builder.ConvWithGeneralPadding(gradients, mirrored_weights, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {0, 2}}); + auto gradients = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 3, /*values=*/{1, 2, 3})); + auto weights = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 2, /*values=*/{1, 10})); + auto mirrored_weights = Rev(weights, {2, 3}); + ConvWithGeneralPadding(gradients, mirrored_weights, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {0, 2}}); ComputeAndCompareR4(&builder, {{{{12, 23, 30, 0}}}}, {}, error_spec_); } @@ -1229,17 +1229,17 @@ XLA_TEST_F(ConvolutionVariantsTest, // weight gradients: 24,130,240 // // This pattern will be fused to backward convolution with padding=(1,2). - auto activations = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 4, /*values=*/{1, 2, 3, 4})); - auto gradients = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 3, /*values=*/{100, 10, 1})); - auto forward_conv = builder.ConvGeneralDilated( - activations, gradients, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {1, 2}}, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, - XlaBuilder::CreateDefaultConvDimensionNumbers()); - builder.Transpose(forward_conv, {0, 1, 2, 3}); + auto activations = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 4, /*values=*/{1, 2, 3, 4})); + auto gradients = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 3, /*values=*/{100, 10, 1})); + auto forward_conv = + ConvGeneralDilated(activations, gradients, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {1, 2}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, + XlaBuilder::CreateDefaultConvDimensionNumbers()); + Transpose(forward_conv, {0, 1, 2, 3}); ComputeAndCompareR4(&builder, {{{{24, 130, 240}}}}, {}, error_spec_); } @@ -1255,17 +1255,17 @@ XLA_TEST_F(ConvolutionVariantsTest, // This pattern will be fused to backward convolution with padding=(2,1). // Note: both (2,1) and (2,0) are valid padding for the backward convolution // because the stride is 2. - auto activations = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 4, /*values=*/{1, 2, 3, 4})); - auto gradients = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 3, /*values=*/{100, 10, 1})); - auto forward_conv = builder.ConvGeneralDilated( - activations, gradients, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {2, 0}}, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, - XlaBuilder::CreateDefaultConvDimensionNumbers()); - builder.Transpose(forward_conv, {0, 1, 2, 3}); + auto activations = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 4, /*values=*/{1, 2, 3, 4})); + auto gradients = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 3, /*values=*/{100, 10, 1})); + auto forward_conv = + ConvGeneralDilated(activations, gradients, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {2, 0}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, + XlaBuilder::CreateDefaultConvDimensionNumbers()); + Transpose(forward_conv, {0, 1, 2, 3}); ComputeAndCompareR4(&builder, {{{{13, 24}}}}, {}, error_spec_); } @@ -1282,17 +1282,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) { // because the stride is 2. ConvolutionFolding prefers (2,2) because cuDNN // supports even padding only -- using (2,1) would need extra effort of // canonicalization. - auto activations = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 4, /*values=*/{1, 2, 3, 4})); - auto gradients = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 3, /*values=*/{100, 10, 1})); - auto forward_conv = builder.ConvGeneralDilated( - activations, gradients, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {2, 1}}, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, - XlaBuilder::CreateDefaultConvDimensionNumbers()); - builder.Transpose(forward_conv, {0, 1, 2, 3}); + auto activations = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 4, /*values=*/{1, 2, 3, 4})); + auto gradients = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 3, /*values=*/{100, 10, 1})); + auto forward_conv = + ConvGeneralDilated(activations, gradients, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {2, 1}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, + XlaBuilder::CreateDefaultConvDimensionNumbers()); + Transpose(forward_conv, {0, 1, 2, 3}); ComputeAndCompareR4(&builder, {{{{13, 24, 130}}}}, {}, error_spec_); } @@ -1300,14 +1300,14 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) { XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding1D) { XlaBuilder builder(TestName()); - auto gradients = builder.ConstantR3FromArray3D( - Array3D(1, 1, 1, /*value=*/1)); + auto gradients = ConstantR3FromArray3D( + &builder, Array3D(1, 1, 1, /*value=*/1)); auto weights = - builder.ConstantR3FromArray3D(Array3D({{{1, 10, 100}}})); - auto mirrored_weights = builder.Rev(weights, {2}); - builder.ConvWithGeneralPadding(gradients, mirrored_weights, - /*window_strides=*/{1}, - /*padding=*/{{1, 1}}); + ConstantR3FromArray3D(&builder, Array3D({{{1, 10, 100}}})); + auto mirrored_weights = Rev(weights, {2}); + ConvWithGeneralPadding(gradients, mirrored_weights, + /*window_strides=*/{1}, + /*padding=*/{{1, 1}}); ComputeAndCompareR3(&builder, {{{10}}}, {}, error_spec_); } @@ -1315,17 +1315,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) { XlaBuilder builder(TestName()); auto activations = - builder.ConstantR3FromArray3D(Array3D({{{1, 2, 3, 4}}})); + ConstantR3FromArray3D(&builder, Array3D({{{1, 2, 3, 4}}})); auto gradients = - builder.ConstantR3FromArray3D(Array3D({{{100, 10, 1}}})); + ConstantR3FromArray3D(&builder, Array3D({{{100, 10, 1}}})); auto forward_conv = - builder.ConvGeneralDilated(activations, gradients, - /*window_strides=*/{1}, - /*padding=*/{{2, 1}}, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{2}, - XlaBuilder::CreateDefaultConvDimensionNumbers( - /*num_spatial_dims=*/1)); - builder.Transpose(forward_conv, {0, 1, 2}); + ConvGeneralDilated(activations, gradients, + /*window_strides=*/{1}, + /*padding=*/{{2, 1}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{2}, + XlaBuilder::CreateDefaultConvDimensionNumbers( + /*num_spatial_dims=*/1)); + Transpose(forward_conv, {0, 1, 2}); ComputeAndCompareR3(&builder, {{{13, 24, 130}}}, {}, error_spec_); } @@ -1336,21 +1336,21 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { auto gradients_flat = Literal::CreateR1({1}); auto gradients_literal = gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); - auto gradients = builder.ConstantLiteral(*gradients_literal); + auto gradients = ConstantLiteral(&builder, *gradients_literal); auto weights_flat = Literal::CreateR1({1, 10, 100}); auto weights_literal = weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); - auto weights = builder.ConstantLiteral(*weights_literal); + auto weights = ConstantLiteral(&builder, *weights_literal); auto expected_flat = Literal::CreateR1({10}); auto expected_literal = expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); - auto mirrored_weights = builder.Rev(weights, {2, 3, 4}); - builder.ConvWithGeneralPadding(gradients, mirrored_weights, - /*window_strides=*/{1, 1, 1}, - /*padding=*/{{0, 0}, {0, 0}, {1, 1}}); + auto mirrored_weights = Rev(weights, {2, 3, 4}); + ConvWithGeneralPadding(gradients, mirrored_weights, + /*window_strides=*/{1, 1, 1}, + /*padding=*/{{0, 0}, {0, 0}, {1, 1}}); ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_); } @@ -1360,25 +1360,25 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { auto activations_flat = Literal::CreateR1({1, 2, 3, 4}); auto activations_literal = activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie(); - auto activations = builder.ConstantLiteral(*activations_literal); + auto activations = ConstantLiteral(&builder, *activations_literal); auto gradients_flat = Literal::CreateR1({100, 10, 1}); auto gradients_literal = gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); - auto gradients = builder.ConstantLiteral(*gradients_literal); + auto gradients = ConstantLiteral(&builder, *gradients_literal); auto expected_flat = Literal::CreateR1({13, 24, 130}); auto expected_literal = expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); - auto forward_conv = builder.ConvGeneralDilated( - activations, gradients, - /*window_strides=*/{1, 1, 1}, - /*padding=*/{{0, 0}, {0, 0}, {2, 1}}, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 1, 2}, - XlaBuilder::CreateDefaultConvDimensionNumbers( - /*num_spatial_dims=*/3)); - builder.Transpose(forward_conv, {0, 1, 2, 3, 4}); + auto forward_conv = + ConvGeneralDilated(activations, gradients, + /*window_strides=*/{1, 1, 1}, + /*padding=*/{{0, 0}, {0, 0}, {2, 1}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 1, 2}, + XlaBuilder::CreateDefaultConvDimensionNumbers( + /*num_spatial_dims=*/3)); + Transpose(forward_conv, {0, 1, 2, 3, 4}); ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_); } diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 2b3390ca98cb2922410d451c06811aa9d4ff8c0b..fef42885e516fa8c8f87756d7a953fe5f37a630f 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -248,7 +248,7 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) { auto empty = Literal::CreateFromShape(in_shape); XlaBuilder builder(TestName()); - auto param0 = builder.Parameter(0, in_shape, "input"); + Parameter(&builder, 0, in_shape, "input"); auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie(); auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape) diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index b43d5c9ff5d75ee0e1b3c9ceb2bc295e631ac107..d1516a28b0bb3857d9aee0922a252e25a8f9d2d5 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" @@ -135,8 +136,8 @@ class CustomCallClientAPITest : public ClientLibraryTestBase {}; // are reserved for internal use. XLA_TEST_F(CustomCallClientAPITest, IllegalCustomCallTarget) { XlaBuilder builder(TestName()); - builder.CustomCall("$illegal", /*operands=*/{}, - ShapeUtil::MakeShape(F32, {1})); + CustomCall(&builder, "$illegal", /*operands=*/{}, + ShapeUtil::MakeShape(F32, {1})); StatusOr> result = Execute(&builder, /*arguments=*/{}); diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc index bfe688e20d182d581c3e3b545ac2289413deef7c..d4b3aac85bff283515088f6e61c9d2bad11f60d3 100644 --- a/tensorflow/compiler/xla/tests/deallocation_test.cc +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -48,7 +48,7 @@ class DeallocationTest : public ClientLibraryTestBase { TEST_F(DeallocationTest, DeallocateScalar) { XlaBuilder builder(TestName()); - builder.ConstantR0(42.0); + ConstantR0(&builder, 42.0); auto global_data = ExecuteAndCheckTransfer(&builder, {}); // A result can be transferred an arbitrary number of times. Add an extra @@ -66,7 +66,7 @@ TEST_F(DeallocationTest, DeallocateScalar) { TEST_F(DeallocationTest, DeallocateVector) { XlaBuilder builder(TestName()); - builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); + ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); ASSERT_IS_OK(client_->Unregister(*global_data)); @@ -79,7 +79,7 @@ TEST_F(DeallocationTest, DeallocateVector) { TEST_F(DeallocationTest, DeallocateEmptyVector) { XlaBuilder builder(TestName()); - builder.ConstantR1({}); + ConstantR1(&builder, {}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); ASSERT_IS_OK(client_->Unregister(*global_data)); @@ -92,8 +92,8 @@ TEST_F(DeallocationTest, DeallocateEmptyVector) { XLA_TEST_F(DeallocationTest, DeallocateTuple) { XlaBuilder builder(TestName()); - builder.Tuple({builder.ConstantR0(42.0), - builder.ConstantR1({1.0, 2.0, 3.0})}); + Tuple(&builder, {ConstantR0(&builder, 42.0), + ConstantR1(&builder, {1.0, 2.0, 3.0})}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); ASSERT_IS_OK(client_->Unregister(*global_data)); @@ -106,9 +106,10 @@ XLA_TEST_F(DeallocationTest, DeallocateTuple) { XLA_TEST_F(DeallocationTest, DeallocateTupleWithRepeatedElements) { XlaBuilder builder(TestName()); - auto element = builder.ConstantR0(42.0); - auto inner_tuple = builder.Tuple({builder.ConstantR0(42.0), element}); - builder.Tuple({element, inner_tuple, element}); + auto element = ConstantR0(&builder, 42.0); + auto inner_tuple = + Tuple(&builder, {ConstantR0(&builder, 42.0), element}); + Tuple(&builder, {element, inner_tuple, element}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); ASSERT_IS_OK(client_->Unregister(*global_data)); @@ -122,9 +123,9 @@ XLA_TEST_F(DeallocationTest, DeallocateTupleWithRepeatedElements) { XLA_TEST_F(DeallocationTest, DeallocateNestedTuple) { XlaBuilder builder(TestName()); auto inner_tuple = - builder.Tuple({builder.ConstantR0(42.0), - builder.ConstantR1({1.0, 2.0, 3.0})}); - builder.Tuple({inner_tuple, builder.ConstantR1({0.123, 0.456})}); + Tuple(&builder, {ConstantR0(&builder, 42.0), + ConstantR1(&builder, {1.0, 2.0, 3.0})}); + Tuple(&builder, {inner_tuple, ConstantR1(&builder, {0.123, 0.456})}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); ASSERT_IS_OK(client_->Unregister(*global_data)); diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index 12789fe66530fe03eb33316eda652336f29971ab..acba67491d25007ab774530fd7ca236a4363b6f0 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -54,9 +54,9 @@ class DeconstructTupleTest : public ClientLibraryTestBase { TEST_F(DeconstructTupleTest, DeconstructTuple) { XlaBuilder builder(TestName()); - auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); - auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); - builder.Tuple({const1, const2}); + auto const1 = ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0}); + auto const2 = ConstantR1(&builder, {2.0, 4.0, 6.0, 8.0}); + Tuple(&builder, {const1, const2}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); auto result_status = client_->DeconstructTuple(*global_data); @@ -73,9 +73,9 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) { TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { XlaBuilder builder(TestName()); - auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); - auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); - builder.Tuple({const1, const2}); + auto const1 = ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0}); + auto const2 = ConstantR1(&builder, {2.0, 4.0, 6.0, 8.0}); + Tuple(&builder, {const1, const2}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); auto result_status1 = client_->DeconstructTuple(*global_data); @@ -103,9 +103,9 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { XlaBuilder builder(TestName()); - auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); - auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); - builder.Tuple({const1, const2, const2, const1}); + auto const1 = ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0}); + auto const2 = ConstantR1(&builder, {2.0, 4.0, 6.0, 8.0}); + Tuple(&builder, {const1, const2, const2, const1}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); auto result_status = client_->DeconstructTuple(*global_data); @@ -129,9 +129,9 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { XlaBuilder builder(TestName()); - auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); - auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); - builder.Tuple({const1, const2, const1}); + auto const1 = ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0}); + auto const2 = ConstantR1(&builder, {2.0, 4.0, 6.0, 8.0}); + Tuple(&builder, {const1, const2, const1}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); auto result_status = client_->DeconstructTuple(*global_data); @@ -159,7 +159,7 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { TEST_F(DeconstructTupleTest, DeconstructNonTuple) { XlaBuilder builder(TestName()); - builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); + ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); auto result_status = client_->DeconstructTuple(*global_data); @@ -174,8 +174,8 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { Literal::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "param0"); - builder.Tuple({p}); + auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0"); + Tuple(&builder, {p}); auto global_data = ExecuteAndCheckTransfer(&builder, {param0_data.get()}); auto result_status = client_->DeconstructTuple(*global_data); @@ -186,9 +186,9 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { XLA_TEST_F(DeconstructTupleTest, DeconstructNestedTuple) { XlaBuilder builder(TestName()); - auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); - auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); - builder.Tuple({builder.Tuple({const1, const2}), const1}); + auto const1 = ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0}); + auto const2 = ConstantR1(&builder, {2.0, 4.0, 6.0, 8.0}); + Tuple(&builder, {Tuple(&builder, {const1, const2}), const1}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); auto result_status = client_->DeconstructTuple(*global_data); diff --git a/tensorflow/compiler/xla/tests/deep_graph_test.cc b/tensorflow/compiler/xla/tests/deep_graph_test.cc index 085a5105aca1c173a7cbc211aebbeb5b254b0753..810947ab01b69b10b6ae60c551bd7aba10a6313d 100644 --- a/tensorflow/compiler/xla/tests/deep_graph_test.cc +++ b/tensorflow/compiler/xla/tests/deep_graph_test.cc @@ -30,7 +30,7 @@ TEST_F(ClientLibraryTestBase, DeepGraph) { auto y_data = CreateR0Parameter(1, 1, "y", &b, &y); XlaOp z = x; for (int i = 0; i < kDepth; ++i) { - z = b.Add(z, y); + z = Add(z, y); } ComputeAndCompareR0(&b, /*expected=*/kDepth + 3, {x_data.get(), y_data.get()}); diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 0fd846cef8095a857dd7b2c12d8afdf409e2bd66..cf2e645d472efab9ca649dbde6602fd4f205d924 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -70,9 +70,9 @@ XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { *Literal::MakeTuple({Literal::CreateR2({{1, 2}, {3, 4}}).get(), Literal::CreateR2({{5, 6}, {7, 8}}).get()}), "arg0", &builder, ¶m); - auto lhs = builder.GetTupleElement(param, 0); - auto rhs = builder.GetTupleElement(param, 1); - builder.Dot(lhs, rhs); + auto lhs = GetTupleElement(param, 0); + auto rhs = GetTupleElement(param, 1); + Dot(lhs, rhs); ComputeAndCompareLiteral(&builder, *Literal::CreateR2({{19, 22}, {43, 50}}), @@ -87,9 +87,9 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ZeroElementVectorDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantR1(&builder, {}); + auto rhs = ConstantR1(&builder, {}); + Dot(lhs, rhs); this->template ComputeAndCompareR0(&builder, static_cast(0.0), {}, this->error_spec_); @@ -102,9 +102,9 @@ TYPED_TEST_CASE(DotOperationTest_F16F32F64, TypesF16F32F64); XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantR2FromArray2D({{3.0f, 4.0f}}); - auto rhs = builder.ConstantFromArray({3.0f, 4.0f}); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, {{3.0f, 4.0f}}); + auto rhs = ConstantFromArray(&builder, {3.0f, 4.0f}); + Dot(lhs, rhs); this->template ComputeAndCompareR1(&builder, {static_cast(25.0f)}, {}, this->error_spec_); @@ -113,9 +113,9 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantR1({static_cast(2.0f)}); - auto rhs = builder.ConstantR1({static_cast(3.0f)}); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantR1(&builder, {static_cast(2.0f)}); + auto rhs = ConstantR1(&builder, {static_cast(3.0f)}); + Dot(lhs, rhs); this->template ComputeAndCompareR0(&builder, static_cast(6.0f), {}, this->error_spec_); @@ -124,9 +124,9 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, VectorDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantFromArray({1.0f, 2.5f, 42.0f}); - auto rhs = builder.ConstantFromArray({11.0f, -1.0f, 0.5f}); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantFromArray(&builder, {1.0f, 2.5f, 42.0f}); + auto rhs = ConstantFromArray(&builder, {11.0f, -1.0f, 0.5f}); + Dot(lhs, rhs); this->template ComputeAndCompareR0(&builder, static_cast(29.5f), {}, this->error_spec_); @@ -139,9 +139,9 @@ std::vector MinorToMajorForIsRowMajor(bool row_major) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, Array2D(0, 2)); + auto rhs = ConstantR2FromArray2D(&builder, Array2D(2, 0)); + Dot(lhs, rhs); this->template ComputeAndCompareR2(&builder, Array2D(0, 0), {}, this->error_spec_); @@ -150,10 +150,10 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto rhs = builder.ConstantR2FromArray2D( - {{7.0f, 8.0f, 9.0f}, {42.0f, 77.0f, 101.0f}}); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, Array2D(0, 2)); + auto rhs = ConstantR2FromArray2D( + &builder, {{7.0f, 8.0f, 9.0f}, {42.0f, 77.0f, 101.0f}}); + Dot(lhs, rhs); this->template ComputeAndCompareR2(&builder, Array2D(0, 3), {}, this->error_spec_); @@ -162,10 +162,10 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantR2FromArray2D( - {{7.0f, 8.0f}, {9.0f, 42.0f}, {77.0f, 101.0f}}); - auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantR2FromArray2D( + &builder, {{7.0f, 8.0f}, {9.0f, 42.0f}, {77.0f, 101.0f}}); + auto rhs = ConstantR2FromArray2D(&builder, Array2D(2, 0)); + Dot(lhs, rhs); this->template ComputeAndCompareR2(&builder, Array2D(3, 0), {}, this->error_spec_); @@ -174,9 +174,9 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); - auto rhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, Array2D(2, 0)); + auto rhs = ConstantR2FromArray2D(&builder, Array2D(0, 2)); + Dot(lhs, rhs); this->template ComputeAndCompareR2( &builder, Array2D(2, 2, static_cast(0.0f)), {}, this->error_spec_); @@ -186,11 +186,11 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, FusedDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto param0 = - builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 4}), "arg0"); + Parameter(&builder, 0, ShapeUtil::MakeShapeWithType({2, 4}), "arg0"); auto param1 = - builder.Parameter(1, ShapeUtil::MakeShapeWithType({4, 1}), "arg1"); - auto exp0 = builder.Exp(param0); - auto result = builder.Dot(exp0, param1); + Parameter(&builder, 1, ShapeUtil::MakeShapeWithType({4, 1}), "arg1"); + auto exp0 = Exp(param0); + Dot(exp0, param1); auto lhs_handle = this->client_ @@ -231,9 +231,8 @@ class SquareMatrixDot : public DotOperationTest { .ConsumeValueOrDie(); XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); - auto result = builder.Dot( - builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs")); + Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"), + Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs")); Array2D expected({{15.0f, -2.0f}, {-25.0f, 34.0f}}); ComputeAndCompareR2(&builder, expected, @@ -316,26 +315,26 @@ void ParametricDotTest::TestImpl() { XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); - auto result = builder.Dot( - builder.Parameter(0, - ShapeUtil::MakeShapeWithLayout( - prim_type, {param.m, param.k}, - MinorToMajorForIsRowMajor(param.dot_lhs_row_major)), - "dot_lhs"), - builder.Parameter(1, - ShapeUtil::MakeShapeWithLayout( - prim_type, {param.k, param.n}, - MinorToMajorForIsRowMajor(param.dot_rhs_row_major)), - "dot_rhs")); + auto result = + Dot(Parameter(&builder, 0, + ShapeUtil::MakeShapeWithLayout( + prim_type, {param.m, param.k}, + MinorToMajorForIsRowMajor(param.dot_lhs_row_major)), + "dot_lhs"), + Parameter(&builder, 1, + ShapeUtil::MakeShapeWithLayout( + prim_type, {param.k, param.n}, + MinorToMajorForIsRowMajor(param.dot_rhs_row_major)), + "dot_rhs")); if (param.has_addend) { - result = builder.Add( - result, builder.Parameter( - 2, - ShapeUtil::MakeShapeWithLayout( - prim_type, {param.m, param.n}, - MinorToMajorForIsRowMajor(param.addend_row_major)), - "addend")); + result = + Add(result, + Parameter(&builder, 2, + ShapeUtil::MakeShapeWithLayout( + prim_type, {param.m, param.n}, + MinorToMajorForIsRowMajor(param.addend_row_major)), + "addend")); } std::unique_ptr> expected; @@ -492,9 +491,8 @@ class NonsquareMatrixDot : public DotOperationTest { XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); - auto result = builder.Dot( - builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs")); + Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"), + Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs")); Array2D expected({{26.0f, 0.0f}, {-12.0f, 10.0f}}); @@ -524,9 +522,8 @@ XLA_TEST_F(DotOperationTest, MatrixVectorC64) { XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); - auto result = builder.Dot( - builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs")); + Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"), + Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs")); Array2D expected({{30.0, -2.0}}); @@ -538,11 +535,13 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, ConcurrentMatMult) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto matrix1 = builder.ConstantR2FromArray2D({{1.0f, 2.0f}, {3.0f, 4.0f}}); - auto matrix2 = builder.ConstantR2FromArray2D({{5.0f, 6.0f}, {7.0f, 8.0f}}); - auto matrix12 = builder.Dot(matrix1, matrix2); - auto matrix21 = builder.Dot(matrix2, matrix1); - builder.Add(matrix12, matrix21); + auto matrix1 = + ConstantR2FromArray2D(&builder, {{1.0f, 2.0f}, {3.0f, 4.0f}}); + auto matrix2 = + ConstantR2FromArray2D(&builder, {{5.0f, 6.0f}, {7.0f, 8.0f}}); + auto matrix12 = Dot(matrix1, matrix2); + auto matrix21 = Dot(matrix2, matrix1); + Add(matrix12, matrix21); Array2D expected({{42.0f, 56.0f}, {74.0f, 96.0f}}); this->template ComputeAndCompareR2(&builder, expected, {}, @@ -559,29 +558,29 @@ TYPED_TEST_CASE(DotOperationTestForBatchMatMul, TypesF16F32F64); XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto x = - builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), "x"); - auto y = - builder.Parameter(1, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), "y"); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), + "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), + "y"); - auto x_flat = builder.Reshape(x, {0, 1, 2, 3}, {4, 2, 2}); - auto y_flat = builder.Reshape(y, {0, 1, 2, 3}, {4, 2, 2}); + auto x_flat = Reshape(x, {0, 1, 2, 3}, {4, 2, 2}); + auto y_flat = Reshape(y, {0, 1, 2, 3}, {4, 2, 2}); // Slice batches into individual matrices and multiply them. std::vector out_slices; for (int i = 0; i < 4; ++i) { // Slice off individual matrices and reshape to 2D tensors. - auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1}); - x_slice = builder.Reshape(x_slice, {0, 1, 2}, {2, 2}); - auto y_slice = builder.Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1}); - y_slice = builder.Reshape(y_slice, {0, 1, 2}, {2, 2}); + auto x_slice = Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1}); + x_slice = Reshape(x_slice, {0, 1, 2}, {2, 2}); + auto y_slice = Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1}); + y_slice = Reshape(y_slice, {0, 1, 2}, {2, 2}); - auto out = builder.Dot(x_slice, y_slice); - out = builder.Reshape(out, {0, 1}, {1, 2, 2}); + auto out = Dot(x_slice, y_slice); + out = Reshape(out, {0, 1}, {1, 2, 2}); out_slices.push_back(out); } - auto out_flat = builder.ConcatInDim(out_slices, 0); - builder.Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); + auto out_flat = ConcatInDim(&builder, out_slices, 0); + Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); auto x_data = this->client_ ->TransferToServer(*Literal::CreateR4FromArray4D( @@ -616,9 +615,9 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) { XlaBuilder builder(this->TestName()); auto x = - builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2, 2}), "x"); + Parameter(&builder, 0, ShapeUtil::MakeShapeWithType({2, 2, 2}), "x"); auto y = - builder.Parameter(1, ShapeUtil::MakeShapeWithType({2, 2, 2}), "y"); + Parameter(&builder, 1, ShapeUtil::MakeShapeWithType({2, 2, 2}), "y"); DotDimensionNumbers dnums; dnums.add_lhs_contracting_dimensions(2); @@ -626,7 +625,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) { dnums.add_lhs_batch_dimensions(0); dnums.add_rhs_batch_dimensions(0); - auto out = builder.DotGeneral(x, y, dnums); + DotGeneral(x, y, dnums); auto x_data = this->client_ @@ -678,19 +677,21 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) { XlaBuilder builder(this->TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); - auto lhs_arg = builder.Parameter( - 0, ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}), + auto lhs_arg = Parameter( + &builder, 0, + ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}), "lhs"); - auto rhs_arg = builder.Parameter( - 1, ShapeUtil::MakeShape(prim_type, {rhs->height(), rhs->width()}), + auto rhs_arg = Parameter( + &builder, 1, + ShapeUtil::MakeShape(prim_type, {rhs->height(), rhs->width()}), "rhs"); if (transpose_lhs) { - lhs_arg = builder.Transpose(lhs_arg, {1, 0}); + lhs_arg = Transpose(lhs_arg, {1, 0}); } if (transpose_rhs) { - rhs_arg = builder.Transpose(rhs_arg, {1, 0}); + rhs_arg = Transpose(rhs_arg, {1, 0}); } - auto result = builder.Dot(lhs_arg, rhs_arg); + Dot(lhs_arg, rhs_arg); Array2D expected({{26.0f, 0.0f}, {-12.0f, 10.0f}}); VLOG(1) << "TestTransposeFolding " << transpose_lhs << " " @@ -713,15 +714,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, {6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}})); XlaBuilder builder(this->TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), - "rhs_arg_0"); - auto rhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), - "rhs_arg_1"); - auto rhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {1, 2}), - "rhs_arg_2"); - auto result = builder.Dot( - lhs_constant, builder.ConcatInDim({rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0)); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_arg_0 = Parameter( + &builder, 0, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs_arg_0"); + auto rhs_arg_1 = Parameter( + &builder, 1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs_arg_1"); + auto rhs_arg_2 = Parameter( + &builder, 2, ShapeUtil::MakeShape(prim_type, {1, 2}), "rhs_arg_2"); + Dot(lhs_constant, + ConcatInDim(&builder, {rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0)); std::unique_ptr> arg_0_value_array( new Array2D({{1.0f, 2.0f}, {3.0f, 4.0f}})); @@ -761,15 +762,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, {2.0f, 1.0f}})); XlaBuilder builder(this->TestName()); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto lhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2}), - "lhs_arg_0"); - auto lhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShapeWithType({2, 3}), - "lhs_arg_1"); - auto lhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShapeWithType({2, 1}), - "lhs_arg_2"); - auto result = builder.Dot( - builder.ConcatInDim({lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), rhs_constant); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto lhs_arg_0 = Parameter( + &builder, 0, ShapeUtil::MakeShapeWithType({2, 2}), "lhs_arg_0"); + auto lhs_arg_1 = Parameter( + &builder, 1, ShapeUtil::MakeShapeWithType({2, 3}), "lhs_arg_1"); + auto lhs_arg_2 = Parameter( + &builder, 2, ShapeUtil::MakeShapeWithType({2, 1}), "lhs_arg_2"); + Dot(ConcatInDim(&builder, {lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), + rhs_constant); std::unique_ptr> arg_0_value_array( new Array2D({{1.0f, 2.0f}, {3.0f, 4.0f}})); @@ -811,16 +812,15 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) { // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({1, 0}); - auto dynamic_slice = - builder.DynamicSlice(lhs_constant, start_constant, {1, 6}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {1, 0}); + auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + DotGeneral(dynamic_slice, rhs_constant, dot_dnums); Array2D expected({{96.0, 105.0, 114.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); @@ -839,25 +839,23 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) { // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({0, 1}); - auto dynamic_slice = - builder.DynamicSlice(rhs_constant, start_constant, {6, 1}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {0, 1}); + auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {6, 1}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + DotGeneral(lhs_constant, dynamic_slice, dot_dnums); Array2D expected({{105.0}, {105.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstRHSReverseMM)))) { + + DotOfGatherOptimizationWithConstRHSReverseMM) { std::unique_ptr> constant_lhs_array( new Array2D({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, @@ -870,25 +868,21 @@ XLA_TEST_F(DotOperationTest, // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({0, 1}); - auto dynamic_slice = - builder.DynamicSlice(lhs_constant, start_constant, {6, 1}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {0, 1}); + auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {6, 1}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(1); - auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + DotGeneral(dynamic_slice, rhs_constant, dot_dnums); Array2D expected({{105.0, 105.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstLHSReverseMM)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSReverseMM) { std::unique_ptr> constant_lhs_array( new Array2D({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, @@ -901,25 +895,21 @@ XLA_TEST_F(DotOperationTest, // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({1, 0}); - auto dynamic_slice = - builder.DynamicSlice(rhs_constant, start_constant, {1, 6}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {1, 0}); + auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(1); - auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + DotGeneral(lhs_constant, dynamic_slice, dot_dnums); Array2D expected({{96.0}, {105.0}, {114.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU( - DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSRows)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSRows) { std::unique_ptr> constant_lhs_array( new Array2D({{1.0, 2.0}, {3.0, 4.0}, @@ -937,25 +927,21 @@ XLA_TEST_F(DotOperationTest, // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({0, 1}); - auto dynamic_slice = - builder.DynamicSlice(lhs_constant, start_constant, {6, 1}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {0, 1}); + auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {6, 1}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); - auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + DotGeneral(dynamic_slice, rhs_constant, dot_dnums); Array2D expected({{126.0, 129.0, 132.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU( - DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSRows)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSRows) { std::unique_ptr> constant_lhs_array( new Array2D({{1.0, 2.0}, {3.0, 4.0}, @@ -973,25 +959,21 @@ XLA_TEST_F(DotOperationTest, // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({0, 1}); - auto dynamic_slice = - builder.DynamicSlice(rhs_constant, start_constant, {6, 1}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {0, 1}); + auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {6, 1}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); - auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + DotGeneral(lhs_constant, dynamic_slice, dot_dnums); Array2D expected({{129.0}, {129.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU( - DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSCols)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSCols) { std::unique_ptr> constant_lhs_array(new Array2D( {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); std::unique_ptr> constant_rhs_array( @@ -1001,25 +983,21 @@ XLA_TEST_F(DotOperationTest, // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({1, 0}); - auto dynamic_slice = - builder.DynamicSlice(lhs_constant, start_constant, {1, 6}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {1, 0}); + auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(1); - auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + DotGeneral(dynamic_slice, rhs_constant, dot_dnums); Array2D expected({{56.0, 168.0, 91.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU( - DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSCols)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSCols) { std::unique_ptr> constant_lhs_array(new Array2D( {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); std::unique_ptr> constant_rhs_array( @@ -1029,19 +1007,41 @@ XLA_TEST_F(DotOperationTest, // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({1, 0}); - auto dynamic_slice = - builder.DynamicSlice(rhs_constant, start_constant, {1, 6}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {1, 0}); + auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(1); - auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + DotGeneral(lhs_constant, dynamic_slice, dot_dnums); Array2D expected({{168.0}, {168.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } + +XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) { + XlaBuilder builder(TestName()); + + Array2D lhs_array({{1.0f, 2.0f}, {3.0f, 4.0f}}); + auto lhs_constant = ConstantR2FromArray2D(&builder, lhs_array); + + Array2D rhs_array({{5.0f, 6.0f}, {7.0f, 8.0f}}); + auto rhs_constant = ConstantR2FromArray2D(&builder, rhs_array); + + Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + DotGeneral(lhs_constant, rhs_constant, dot_dnums); + + Array2D expected({ + {26.f, 30.f}, + {38.f, 44.f}, + }); + + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 49f3a10d227f2f9edfe76405ba13498fe822f8d8..f3c258a4d4c446c465320ac16ef7c72e299a51a8 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -138,8 +138,8 @@ class DynamicSliceTest : public ClientLibraryTestBase { std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantLiteral(input_values); - builder.DynamicSlice(input, starts, slice_sizes); + auto input = ConstantLiteral(&builder, input_values); + DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -164,8 +164,8 @@ class DynamicSliceTest : public ClientLibraryTestBase { std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantLiteral(input_values); - builder.DynamicSlice(input, starts, slice_sizes); + auto input = ConstantLiteral(&builder, input_values); + DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -190,8 +190,8 @@ class DynamicSliceTest : public ClientLibraryTestBase { std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantLiteral(input_values); - builder.DynamicSlice(input, starts, slice_sizes); + auto input = ConstantLiteral(&builder, input_values); + DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -367,9 +367,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantLiteral(input_value); - auto update = builder.ConstantLiteral(update_value); - builder.DynamicUpdateSlice(input, update, starts); + auto input = ConstantLiteral(&builder, input_value); + auto update = ConstantLiteral(&builder, update_value); + DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_value, {start_data.get()}); } @@ -398,9 +398,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantLiteral(input_values); - auto update = builder.ConstantLiteral(update_values); - builder.DynamicUpdateSlice(input, update, starts); + auto input = ConstantLiteral(&builder, input_values); + auto update = ConstantLiteral(&builder, update_values); + DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -429,9 +429,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantLiteral(input_values); - auto update = builder.ConstantLiteral(update_values); - builder.DynamicUpdateSlice(input, update, starts); + auto input = ConstantLiteral(&builder, input_values); + auto update = ConstantLiteral(&builder, update_values); + DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -460,9 +460,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantLiteral(input_values); - auto update = builder.ConstantLiteral(update_values); - builder.DynamicUpdateSlice(input, update, starts); + auto input = ConstantLiteral(&builder, input_values); + auto update = ConstantLiteral(&builder, update_values); + DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -508,8 +508,8 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { XlaOp update; std::unique_ptr update_data = CreateR3Parameter( update_values, 1, "update_values", &builder, &update); - auto starts = builder.ConstantR1({index, 0, 0}); - builder.DynamicUpdateSlice(input, update, starts); + auto starts = ConstantR1(&builder, {index, 0, 0}); + DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. ComputeAndCompareR3(&builder, expected_values, @@ -698,14 +698,14 @@ void BM_DynamicSlice(int num_iters) { auto input_literal = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); - auto input = builder.ConstantLiteral(*input_literal); + auto input = ConstantLiteral(&builder, *input_literal); // Create dynamic slice start indices as a parameter: shape [4] auto start_indices_shape = ShapeUtil::MakeShape(S32, {4}); auto start_indices = - builder.Parameter(0, start_indices_shape, "start_indices"); + Parameter(&builder, 0, start_indices_shape, "start_indices"); // Add DynamicSlice op to the computatation. - builder.DynamicSlice(input, start_indices, {1, 1, 1, 1}); + DynamicSlice(input, start_indices, {1, 1, 1, 1}); auto computation = builder.Build().ConsumeValueOrDie(); // Initialize and transfer parameter buffer. @@ -716,8 +716,10 @@ void BM_DynamicSlice(int num_iters) { .ConsumeValueOrDie(); auto start_indices_literal = Literal::CreateR1({0, 1, 2, 3}); + auto stream = + client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *start_indices_literal, buffer)); + stream.get(), *start_indices_literal, buffer)); std::unique_ptr executable = client diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc index a6ba6db5d3bf86de91f6fda022c46afee01281c2..ddc6a7db18760bf951023f0a684d78739f3e869d 100644 --- a/tensorflow/compiler/xla/tests/execution_profile_test.cc +++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc @@ -34,7 +34,7 @@ XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) { *Literal::CreateR2F32Linspace(1e0, 1e5, 256, 256))); XlaBuilder b(TestName() + ".add"); - b.Dot(b.Parameter(0, shape, "param_0"), b.Parameter(1, shape, "param_1")); + Dot(Parameter(&b, 0, shape, "param_0"), Parameter(&b, 1, shape, "param_1")); TF_ASSERT_OK_AND_ASSIGN(XlaComputation dot_product, b.Build()); ExecutionProfile execution_profile; diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc index 0a37e4d423620122f2e109343a86a964f46d778f..74cf8b213e0a03394c84008e7a2919e1a5bf1af2 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -54,7 +54,7 @@ class ExhaustiveF32ElementwiseOpTest TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, client_->TransferToServer(*input_literal)); - auto input = builder.Parameter(0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal->shape(), "input"); enqueue_op(&builder, input); std::vector expected_result; @@ -79,8 +79,8 @@ XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, LogF32) { #endif ExhaustivelyTestF32Op( - [](XlaBuilder* builder, const XlaOp& input) { builder->Log(input); }, - std::log, known_incorrect_range); + [](XlaBuilder* builder, const XlaOp& input) { Log(input); }, std::log, + known_incorrect_range); } XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ExpF32) { @@ -95,14 +95,14 @@ XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ExpF32) { #endif ExhaustivelyTestF32Op( - [](XlaBuilder* builder, const XlaOp& input) { builder->Exp(input); }, - std::exp, known_incorrect_range); + [](XlaBuilder* builder, const XlaOp& input) { Exp(input); }, std::exp, + known_incorrect_range); } XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, TanhF32) { ExhaustivelyTestF32Op( - [](XlaBuilder* builder, const XlaOp& input) { builder->Tanh(input); }, - std::tanh, /*known_incorrect_range=*/{0, 0}); + [](XlaBuilder* builder, const XlaOp& input) { Tanh(input); }, std::tanh, + /*known_incorrect_range=*/{0, 0}); } std::vector> CreateExhaustiveParameters() { diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc index 71eb914a8e5eaef2e38b9e6e7d45b8a10ce1bd7a..30dc639f117b9871238f0bf1628502cf8bef2e0c 100644 --- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc +++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc @@ -42,12 +42,12 @@ class FloorCeilTest : public ClientLibraryTestBase { LOG(INFO) << "input: {" << tensorflow::str_util::Join(expected, ", ") << "}"; XlaBuilder builder(TestName()); - auto c = builder.ConstantR1(input); + auto c = ConstantR1(&builder, input); if (f == kCeil) { - builder.Ceil(c); + Ceil(c); } else { ASSERT_EQ(kFloor, f); - builder.Floor(c); + Floor(c); } ComputeAndCompareR1(&builder, expected, /*arguments=*/{}); } @@ -55,12 +55,12 @@ class FloorCeilTest : public ClientLibraryTestBase { void TestR0F32(float input, float expected, Function f) { LOG(INFO) << "input: " << expected; XlaBuilder builder(TestName()); - auto c = builder.ConstantR0(input); + auto c = ConstantR0(&builder, input); if (f == kCeil) { - builder.Ceil(c); + Ceil(c); } else { ASSERT_EQ(kFloor, f); - builder.Floor(c); + Floor(c); } ComputeAndCompareR0(&builder, expected, /*arguments=*/{}); } diff --git a/tensorflow/compiler/xla/tests/fmax_test.cc b/tensorflow/compiler/xla/tests/fmax_test.cc index 73f029b59bc56aa6c3e86200a49fcae0fd177101..0254ae1baaa864b38c3b217a5c2026d34b7f7d12 100644 --- a/tensorflow/compiler/xla/tests/fmax_test.cc +++ b/tensorflow/compiler/xla/tests/fmax_test.cc @@ -28,11 +28,11 @@ class FmaxSimpleTest : public ClientLibraryTestBase {}; TEST_F(FmaxSimpleTest, FmaxTenValues) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); - auto y = builder.ConstantR1( - {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0}); - builder.Max(x, y); + auto x = ConstantR1( + &builder, {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); + auto y = ConstantR1( + &builder, {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0}); + Max(x, y); std::vector expected = {-0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index e6f79b5ac55dddfbb213a36cadbee53bc9443d9d..ab470f16a32c2363e88a11a9f7d564dcf2981f42 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -557,8 +557,7 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { *ExecuteAndTransfer(std::move(hlo_module), {}))); } -// TODO(b/64070202): Investigate failure. -XLA_TEST_F(FusionTest, DISABLED_ON_GPU(TransposeNegate)) { +XLA_TEST_F(FusionTest, TransposeNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -793,14 +792,14 @@ void BM_ParallelFusion(int num_iters) { // Create computation. XlaBuilder builder("ParallelFusion"); Shape shape0 = ShapeUtil::MakeShape(F32, {param0_dim0, param0_dim1}); - auto param0 = builder.Parameter(0, shape0, "param0"); + auto param0 = Parameter(&builder, 0, shape0, "param0"); Shape shape1 = ShapeUtil::MakeShape(F32, {param1_dim0, param1_dim1}); - auto param1 = builder.Parameter(1, shape1, "param1"); + auto param1 = Parameter(&builder, 1, shape1, "param1"); Shape shape2 = ShapeUtil::MakeShape(F32, {param2_dim0, param2_dim1}); - auto param2 = builder.Parameter(2, shape2, "param2"); + auto param2 = Parameter(&builder, 2, shape2, "param2"); - auto x = builder.Mul(param0, param1); - auto y = builder.Add(x, param2); + auto x = Mul(param0, param1); + Add(x, param2); auto computation = builder.Build().ConsumeValueOrDie(); // Transfer literals to device. diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 143ffbdeb409d91ab6d46d386aa5ff98ebc4ae10..b8404826b161b9edbbd260d73c175cce935ace91 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -598,14 +599,14 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3}); Shape indices_shape = ShapeUtil::MakeShape(S32, {2}); - auto operand = builder.Parameter(0, operand_shape, "operand"); - auto indices = builder.Parameter(1, indices_shape, "indices"); + auto operand = Parameter(&builder, 0, operand_shape, "operand"); + auto indices = Parameter(&builder, 1, indices_shape, "indices"); GatherDimensionNumbers dim_numbers; dim_numbers.add_output_window_dims(1); dim_numbers.add_elided_window_dims(0); dim_numbers.add_gather_dims_to_operand_dims(0); dim_numbers.set_index_vector_dim(1); - builder.Gather(operand, indices, dim_numbers, {1, 3}); + Gather(operand, indices, dim_numbers, {1, 3}); std::vector expected = {}; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr operand_arg, @@ -629,8 +630,8 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { client_->ExecuteParallel(computation_instances)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, client_->Transfer(*(result_data[0]))); - EXPECT_TRUE(LiteralTestUtil::Equal( - *result_literal, *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}))); + LiteralTestUtil::ExpectR2Equal({{1, 2, 3}, {7, 8, 9}}, + *result_literal); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc index 76bf47845ca045b4eede9a3b47ae5c2ce93ce577..fd8511884907ae500d8256c3250fe779f8eba83a 100644 --- a/tensorflow/compiler/xla/tests/half_test.cc +++ b/tensorflow/compiler/xla/tests/half_test.cc @@ -37,8 +37,7 @@ class HalfTestBase : public ClientLibraryTestBase { static const int kNumElements = 4; }; -using UnaryBuildFuncTy = - std::function; +using UnaryBuildFuncTy = std::function; struct UnaryOpTestParam { std::function compute_func; @@ -62,7 +61,7 @@ XLA_TEST_P(UnaryOpTest, Ops) { } UnaryBuildFuncTy build_func = GetParam().build_func; - build_func(&builder, x_opnd); + build_func(x_opnd); ComputeAndCompareR1(&builder, expected, {x_data.get()}, error_spec_); } @@ -79,18 +78,17 @@ half round_imp(half value) { INSTANTIATE_TEST_CASE_P( half, UnaryOpTest, ::testing::Values( - UnaryOpTestParam{[](half x) { return abs(x); }, &XlaBuilder::Abs}, - UnaryOpTestParam{[](half x) { return round_imp(x); }, - &XlaBuilder::Round}, - UnaryOpTestParam{[](half x) { return ceil(x); }, &XlaBuilder::Ceil}, - UnaryOpTestParam{[](half x) { return cos(x); }, &XlaBuilder::Cos}, - UnaryOpTestParam{[](half x) { return exp(x); }, &XlaBuilder::Exp}, - UnaryOpTestParam{[](half x) { return floor(x); }, &XlaBuilder::Floor}, - UnaryOpTestParam{[](half x) { return log(x); }, &XlaBuilder::Log}, - UnaryOpTestParam{[](half x) { return -x; }, &XlaBuilder::Neg}, - UnaryOpTestParam{[](half x) { return sign_imp(x); }, &XlaBuilder::Sign}, - UnaryOpTestParam{[](half x) { return sin(x); }, &XlaBuilder::Sin}, - UnaryOpTestParam{[](half x) { return tanh(x); }, &XlaBuilder::Tanh} + UnaryOpTestParam{[](half x) { return abs(x); }, &Abs}, + UnaryOpTestParam{[](half x) { return round_imp(x); }, &Round}, + UnaryOpTestParam{[](half x) { return ceil(x); }, &Ceil}, + UnaryOpTestParam{[](half x) { return cos(x); }, &Cos}, + UnaryOpTestParam{[](half x) { return exp(x); }, &Exp}, + UnaryOpTestParam{[](half x) { return floor(x); }, &Floor}, + UnaryOpTestParam{[](half x) { return log(x); }, &Log}, + UnaryOpTestParam{[](half x) { return -x; }, &Neg}, + UnaryOpTestParam{[](half x) { return sign_imp(x); }, &Sign}, + UnaryOpTestParam{[](half x) { return sin(x); }, &Sin}, + UnaryOpTestParam{[](half x) { return tanh(x); }, &Tanh} )); @@ -118,19 +116,18 @@ XLA_TEST_P(UnaryPredTest, Ops) { } UnaryBuildFuncTy build_func = GetParam().build_func; - build_func(&builder, x_opnd); + build_func(x_opnd); ComputeAndCompareR1(&builder, expected, {x_data.get()}); } INSTANTIATE_TEST_CASE_P(half, UnaryPredTest, ::testing::Values(UnaryPredTestParam{ - [](half x) { return isfinite(x); }, - &XlaBuilder::IsFinite})); + [](half x) { return isfinite(x); }, &IsFinite})); -using BinaryBuildFuncTy = std::function)>; +using BinaryBuildFuncTy = + std::function)>; struct BinaryOpTestParam { std::function compute_func; @@ -159,7 +156,7 @@ XLA_TEST_P(BinaryOpTest, Ops) { } BinaryBuildFuncTy build_func = GetParam().build_func; - build_func(&builder, x_opnd, y_opnd, {}); + build_func(x_opnd, y_opnd, {}); ComputeAndCompareR1(&builder, expected, {x_data.get(), y_data.get()}, error_spec_); @@ -173,22 +170,15 @@ half atan2_imp(half x, half y) { INSTANTIATE_TEST_CASE_P( half, BinaryOpTest, ::testing::Values( - BinaryOpTestParam{[](half x, half y) { return x + y; }, - &XlaBuilder::Add}, + BinaryOpTestParam{[](half x, half y) { return x + y; }, &Add}, BinaryOpTestParam{[](half x, half y) { return atan2_imp(x, y); }, - &XlaBuilder::Atan2}, - BinaryOpTestParam{[](half x, half y) { return x / y; }, - &XlaBuilder::Div}, - BinaryOpTestParam{[](half x, half y) { return max(x, y); }, - &XlaBuilder::Max}, - BinaryOpTestParam{[](half x, half y) { return min(x, y); }, - &XlaBuilder::Min}, - BinaryOpTestParam{[](half x, half y) { return x * y; }, - &XlaBuilder::Mul}, - BinaryOpTestParam{[](half x, half y) { return pow(x, y); }, - &XlaBuilder::Pow}, - BinaryOpTestParam{[](half x, half y) { return x - y; }, - &XlaBuilder::Sub} + &Atan2}, + BinaryOpTestParam{[](half x, half y) { return x / y; }, &Div}, + BinaryOpTestParam{[](half x, half y) { return max(x, y); }, &Max}, + BinaryOpTestParam{[](half x, half y) { return min(x, y); }, &Min}, + BinaryOpTestParam{[](half x, half y) { return x * y; }, &Mul}, + BinaryOpTestParam{[](half x, half y) { return pow(x, y); }, &Pow}, + BinaryOpTestParam{[](half x, half y) { return x - y; }, &Sub} )); @@ -221,27 +211,22 @@ XLA_TEST_P(BinaryPredTest, Ops) { } BinaryBuildFuncTy build_func = GetParam().build_func; - build_func(&builder, x_opnd, y_opnd, {}); + build_func(x_opnd, y_opnd, {}); ComputeAndCompareR1(&builder, expected, {x_data.get(), y_data.get()}); } INSTANTIATE_TEST_CASE_P( half, BinaryPredTest, - ::testing::Values(BinaryPredTestParam{[](half x, half y) { return x == y; }, - &XlaBuilder::Eq}, - BinaryPredTestParam{[](half x, half y) { return x != y; }, - &XlaBuilder::Ne}, - BinaryPredTestParam{[](half x, half y) { return x >= y; }, - &XlaBuilder::Ge}, - BinaryPredTestParam{[](half x, half y) { return x > y; }, - &XlaBuilder::Gt}, - BinaryPredTestParam{[](half x, half y) { return x <= y; }, - &XlaBuilder::Le}, - BinaryPredTestParam{[](half x, half y) { return x < y; }, - &XlaBuilder::Lt} - - )); + ::testing::Values( + BinaryPredTestParam{[](half x, half y) { return x == y; }, &Eq}, + BinaryPredTestParam{[](half x, half y) { return x != y; }, &Ne}, + BinaryPredTestParam{[](half x, half y) { return x >= y; }, &Ge}, + BinaryPredTestParam{[](half x, half y) { return x > y; }, &Gt}, + BinaryPredTestParam{[](half x, half y) { return x <= y; }, &Le}, + BinaryPredTestParam{[](half x, half y) { return x < y; }, &Lt} + + )); } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc index cf971dd61b71ad329b20b0bb7c16166126562681..4d82442f7e3630c115eff1f17544e2b892c5e7eb 100644 --- a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc +++ b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc @@ -30,9 +30,9 @@ class HloMetadataTest : public LocalClientTestBase { } void BuildAddComputation(XlaBuilder* builder) { - auto x = builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - builder->Add(x, y); + auto x = Parameter(builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Add(x, y); } OpMetadata metadata_; diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 249da87f489324ed9d377cc46a15cef5a9e74192..9009d67cea6840235d63724ef76d777c8f693d33 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -185,13 +185,9 @@ class HloTestBase : public ::testing::Test { // 'layout'. void ForceParameterLayout(HloModule* module, int64 param_no, const Layout& layout) { - ASSERT_LT( - param_no, - module->mutable_host_entry_computation_layout()->parameter_count()); - module->mutable_host_entry_computation_layout() - ->mutable_parameter_layout(param_no) - ->ResetLayout(layout); - module->mutable_device_entry_computation_layout() + ASSERT_LT(param_no, + module->mutable_entry_computation_layout()->parameter_count()); + module->mutable_entry_computation_layout() ->mutable_parameter_layout(param_no) ->ResetLayout(layout); } @@ -199,10 +195,7 @@ class HloTestBase : public ::testing::Test { // Convenience method to force the layout of the computation result in a // module. The result layout of 'module' is set to 'layout'. void ForceResultLayout(HloModule* module, const Layout& layout) { - module->mutable_host_entry_computation_layout() - ->mutable_result_layout() - ->ResetLayout(layout); - module->mutable_device_entry_computation_layout() + module->mutable_entry_computation_layout() ->mutable_result_layout() ->ResetLayout(layout); } @@ -210,10 +203,7 @@ class HloTestBase : public ::testing::Test { // Convenience method to clear the layout of the computation result in // 'module'. void ForceClearResultLayout(HloModule* module) { - module->mutable_host_entry_computation_layout() - ->mutable_result_layout() - ->Clear(); - module->mutable_device_entry_computation_layout() + module->mutable_entry_computation_layout() ->mutable_result_layout() ->Clear(); } diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index 22c664d1426c598dbb695ff1b66ce009b0a19c00..ad1f5b9eed8b5b140100c1fa35dc7d698e3db48b 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -72,10 +72,10 @@ HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) { return modules_.back().get(); } -void HloVerifiedTestBase::ParseAndVerifyModule( - tensorflow::StringPiece hlo_text) { +void HloVerifiedTestBase::ParseAndVerifyModule(tensorflow::StringPiece hlo_text, + const HloModuleConfig& config) { CHECK(!module_) << "Called ParseModule when test already has a module."; - TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text, config)); VerifyModule(module_.get()); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index 5b59cc77f61b05092d3afb331e73932c9edc5840..5b28c01c369fa1ae1c7941f5c8139882c4dbed08 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -44,7 +44,8 @@ class HloVerifiedTestBase : public HloTestBase { // Returns the default HloModule, lazily creating it if necessary via // HloTestBase::CreateNewModule(). HloModule& module(); - void ParseAndVerifyModule(tensorflow::StringPiece hlo_text); + void ParseAndVerifyModule(tensorflow::StringPiece hlo_text, + const HloModuleConfig& config = HloModuleConfig()); // Sets the shape-size function used during hlo verification. If this isn't // called, a default ShapeVerifier is used instead. diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc index f21f83992ffb7c07dff31c68a7e9e3f7944bf512..9191be9fd905ab2e0c661042b042c8233d39e4a1 100644 --- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc @@ -38,9 +38,9 @@ class LocalClientAllocationTest : public LocalClientTestBase { XLA_TEST_F(LocalClientAllocationTest, AddVectors) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({0.0f, 1.0f, 2.0f}); - auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); - builder.Add(x, y); + auto x = ConstantR1(&builder, {0.0f, 1.0f, 2.0f}); + auto y = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); + Add(x, y); TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform()); @@ -74,9 +74,9 @@ XLA_TEST_F(LocalClientAllocationTest, RunOnDevices) { // Run a computation on every device on the system. Verify that allocation // occurs on the proper device. XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({0.0f, 1.0f, 2.0f}); - auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); - builder.Add(x, y); + auto x = ConstantR1(&builder, {0.0f, 1.0f, 2.0f}); + auto y = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); + Add(x, y); auto computation = builder.Build().ConsumeValueOrDie(); TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform()); diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index a366afe8262e1f537b225e395bba9cb2fc22683a..70612e7c49d2815096cc54fd6ae796148249b4db 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -37,8 +37,8 @@ using xla::string; xla::XlaComputation Doubler() { xla::XlaBuilder builder("doubler"); auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {}); - auto x = builder.Parameter(0, r0f32, "x"); - builder.Mul(x, builder.ConstantR0(2.0)); + auto x = xla::Parameter(&builder, 0, r0f32, "x"); + xla::Mul(x, xla::ConstantR0(&builder, 2.0)); return std::move(builder.Build().ValueOrDie()); } @@ -51,10 +51,10 @@ int main(int argc, char** argv) { xla::XlaBuilder builder("aot_test_helper"); auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); - auto opaque_param = builder.Parameter(0, opaque_shape, "x"); + auto opaque_param = Parameter(&builder, 0, opaque_shape, "x"); auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {}); - auto sum = builder.CustomCall("SumStructElements", {opaque_param}, r0f32); - builder.Call(Doubler(), {sum}); + auto sum = CustomCall(&builder, "SumStructElements", {opaque_param}, r0f32); + Call(&builder, Doubler(), {sum}); if (argc != 2) { LOG(FATAL) << "local_client_aot_test_helper TARGET_CPU"; diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 96858c00d6bbe59b673a34e7d5ca261756709596..2c6393794ef1b1558f5e651b5cb7bfa2afa961de 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -54,7 +54,7 @@ class LocalClientExecuteTest : public LocalClientTestBase { XLA_TEST_F(LocalClientExecuteTest, Constant) { XlaBuilder builder(TestName()); - auto y = builder.ConstantR0(123.0f); + ConstantR0(&builder, 123.0f); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); @@ -64,9 +64,9 @@ XLA_TEST_F(LocalClientExecuteTest, Constant) { XLA_TEST_F(LocalClientExecuteTest, AddScalars) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.ConstantR0(123.0f); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = ConstantR0(&builder, 123.0f); + Add(x, y); auto x_value = LiteralToShapedBuffer(*Literal::CreateR0(42.0f)); ScopedShapedBuffer result = @@ -77,9 +77,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddScalars) { XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0}), "x"); - auto y = builder.ConstantR1({}); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {0}), "x"); + auto y = ConstantR1(&builder, {}); + Add(x, y); auto x_array = LiteralToShapedBuffer(*Literal::CreateR1({})); ScopedShapedBuffer result = @@ -90,9 +90,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) { XLA_TEST_F(LocalClientExecuteTest, AddVectors) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); - auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); + auto y = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); + Add(x, y); auto x_array = LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); @@ -104,9 +104,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectors) { XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); - auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); + auto y = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); + Add(x, y); auto x_array = LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); @@ -122,9 +122,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); + Add(x, y); auto computation = builder.Build().ConsumeValueOrDie(); // Create x as a col-major array. @@ -155,9 +155,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); + Add(x, y); auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( @@ -192,9 +192,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { XLA_TEST_F(LocalClientExecuteTest, TupleResult) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); - builder.Tuple({x, y, x}); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); + Tuple(&builder, {x, y, x}); auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( @@ -209,21 +209,20 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0})); - LiteralTestUtil::ExpectR2Equal( - {{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralSlice(*result_literal, {1})); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {2})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {0})); + LiteralTestUtil::ExpectR2Equal({{10.0f, 20.0f}, {30.0f, 40.0f}}, + LiteralSlice(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {2})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); - auto inner_tuple = builder.Tuple({x, y, x}); - builder.Tuple({inner_tuple, x}); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); + auto inner_tuple = Tuple(&builder, {x, y, x}); + Tuple(&builder, {inner_tuple, x}); auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( @@ -238,25 +237,22 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1})); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0, 0})); - LiteralTestUtil::ExpectR2Equal( - {{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralSlice(*result_literal, {0, 1})); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0, 2})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {0, 0})); + LiteralTestUtil::ExpectR2Equal({{10.0f, 20.0f}, {30.0f, 40.0f}}, + LiteralSlice(*result_literal, {0, 1})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {0, 2})); } XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { // Verify setting the result layout of a computation with a tuple output. XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); - builder.Tuple({x, y}); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); + Tuple(&builder, {x, y}); auto array = LiteralToShapedBuffer( *Literal::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); @@ -273,10 +269,10 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { options, DefaultExecutableRunOptions()); std::unique_ptr result_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0})); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {0})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { @@ -291,15 +287,15 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { // Computation adds the respective array and vector elements from each tuple // argument and returns the results as a tuple. XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, tuple_shape0, "x"); - auto y = builder.Parameter(1, tuple_shape1, "y"); - auto x_0 = builder.GetTupleElement(x, 0); - auto x_1 = builder.GetTupleElement(x, 1); - auto y_0 = builder.GetTupleElement(y, 0); - auto y_1 = builder.GetTupleElement(y, 1); - auto array_sum = builder.Add(x_0, y_1); - auto vector_diff = builder.Sub(x_1, y_0); - builder.Tuple({array_sum, vector_diff}); + auto x = Parameter(&builder, 0, tuple_shape0, "x"); + auto y = Parameter(&builder, 1, tuple_shape1, "y"); + auto x_0 = GetTupleElement(x, 0); + auto x_1 = GetTupleElement(x, 1); + auto y_0 = GetTupleElement(y, 0); + auto y_1 = GetTupleElement(y, 1); + auto array_sum = Add(x_0, y_1); + auto vector_diff = Sub(x_1, y_0); + Tuple(&builder, {array_sum, vector_diff}); auto computation = builder.Build().ConsumeValueOrDie(); auto x_literal = Literal::MakeTuple( @@ -319,11 +315,10 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR2Equal( - {{56.0f, 46.0f}, {36.0f, 26.0f}}, - LiteralSlice(*result_literal, {0})); - LiteralTestUtil::ExpectR1Equal( - {40.0f, 71.0f, 117.0f}, LiteralSlice(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{56.0f, 46.0f}, {36.0f, 26.0f}}, + LiteralSlice(*result_literal, {0})); + LiteralTestUtil::ExpectR1Equal({40.0f, 71.0f, 117.0f}, + LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { @@ -338,15 +333,15 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { // Computation negates the array element and sums the two vector elements in // the nested tuple. The resulting array and vector are returned as a tuple. XlaBuilder builder(TestName()); - auto param = builder.Parameter(0, nested_tuple_shape, "param"); - auto inner_tuple = builder.GetTupleElement(param, 0); - auto inner_array = builder.GetTupleElement(inner_tuple, 0); - auto inner_vector = builder.GetTupleElement(inner_tuple, 1); - auto outer_vector = builder.GetTupleElement(param, 1); - - auto negate_array = builder.Neg(inner_array); - auto vector_sum = builder.Add(inner_vector, outer_vector); - builder.Tuple({negate_array, vector_sum}); + auto param = Parameter(&builder, 0, nested_tuple_shape, "param"); + auto inner_tuple = GetTupleElement(param, 0); + auto inner_array = GetTupleElement(inner_tuple, 0); + auto inner_vector = GetTupleElement(inner_tuple, 1); + auto outer_vector = GetTupleElement(param, 1); + + auto negate_array = Neg(inner_array); + auto vector_sum = Add(inner_vector, outer_vector); + Tuple(&builder, {negate_array, vector_sum}); auto computation = builder.Build().ConsumeValueOrDie(); auto arg_literal = Literal::MakeTuple( @@ -360,10 +355,10 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); std::unique_ptr result_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR2Equal( - {{-1.0, -2.0}, {-3.0, -4}}, LiteralSlice(*result_literal, {0})); - LiteralTestUtil::ExpectR1Equal( - {264.0, 73.0, 133.0}, LiteralSlice(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4}}, + LiteralSlice(*result_literal, {0})); + LiteralTestUtil::ExpectR1Equal({264.0, 73.0, 133.0}, + LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { @@ -376,10 +371,10 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { ShapeUtil::MakeTupleShape({array_shape, array_shape}); XlaBuilder builder(TestName()); - auto param = builder.Parameter(0, tuple_shape, "param"); - auto element_0 = builder.GetTupleElement(param, 0); - auto element_1 = builder.GetTupleElement(param, 1); - builder.Tuple({builder.Neg(element_0), builder.Add(element_1, element_1)}); + auto param = Parameter(&builder, 0, tuple_shape, "param"); + auto element_0 = GetTupleElement(param, 0); + auto element_1 = GetTupleElement(param, 1); + Tuple(&builder, {Neg(element_0), Add(element_1, element_1)}); auto computation = builder.Build().ConsumeValueOrDie(); auto arg_literal = Literal::MakeTuple( @@ -389,18 +384,17 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer}); std::unique_ptr result_0_literal = ShapedBufferToLiteral(result_0); - LiteralTestUtil::ExpectR2Equal( - {{-1.0, -2.0}, {-3.0, -4.0}}, - LiteralSlice(*result_0_literal, {0})); - LiteralTestUtil::ExpectR2Equal( - {{22.0, 6.0}, {8.0, 10}}, LiteralSlice(*result_0_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4.0}}, + LiteralSlice(*result_0_literal, {0})); + LiteralTestUtil::ExpectR2Equal({{22.0, 6.0}, {8.0, 10}}, + LiteralSlice(*result_0_literal, {1})); ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0}); std::unique_ptr result_1_literal = ShapedBufferToLiteral(result_1); - LiteralTestUtil::ExpectR2Equal( - {{1.0, 2.0}, {3.0, 4.0}}, LiteralSlice(*result_1_literal, {0})); - LiteralTestUtil::ExpectR2Equal( - {{44.0, 12.0}, {16.0, 20}}, LiteralSlice(*result_1_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{1.0, 2.0}, {3.0, 4.0}}, + LiteralSlice(*result_1_literal, {0})); + LiteralTestUtil::ExpectR2Equal({{44.0, 12.0}, {16.0, 20}}, + LiteralSlice(*result_1_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { @@ -420,16 +414,15 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { const Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes); XlaBuilder builder(TestName()); - auto param = builder.Parameter(0, tuple_shape, "param"); + auto param = Parameter(&builder, 0, tuple_shape, "param"); // Add each element's tuple index value to every element. std::vector result_elements; for (int i = 0; i < kElementCount; ++i) { - auto element = builder.GetTupleElement(param, i); - result_elements.push_back( - builder.Add(element, builder.ConstantR0(i))); + auto element = GetTupleElement(param, i); + result_elements.push_back(Add(element, ConstantR0(&builder, i))); } - builder.Tuple(result_elements); + Tuple(&builder, result_elements); auto computation = builder.Build().ConsumeValueOrDie(); // Feed in a tuple where each two-element vector element is {tuple_index, @@ -447,8 +440,7 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { for (int i = 0; i < kElementCount; ++i) { LiteralTestUtil::ExpectR1Near( - {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), - error_spec_); + {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), error_spec_); } } @@ -465,22 +457,22 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { const Shape tuple_shape = ShapeUtil::MakeTupleShape(inner_tuple_shapes); XlaBuilder builder(TestName()); - auto param = builder.Parameter(0, tuple_shape, "param"); + auto param = Parameter(&builder, 0, tuple_shape, "param"); // The computation increments each leaf value by an amount equal to the leaf's // ordinal position in a traversal of the tuple. std::vector result_elements; for (int i = 0; i < kFanout; ++i) { - auto outer_element = builder.GetTupleElement(param, i); + auto outer_element = GetTupleElement(param, i); std::vector inner_result_elements; for (int j = 0; j < kFanout; ++j) { - auto inner_element = builder.GetTupleElement(outer_element, j); - inner_result_elements.push_back(builder.Add( - inner_element, builder.ConstantR0(i * kFanout + j))); + auto inner_element = GetTupleElement(outer_element, j); + inner_result_elements.push_back( + Add(inner_element, ConstantR0(&builder, i * kFanout + j))); } - result_elements.push_back(builder.Tuple(inner_result_elements)); + result_elements.push_back(Tuple(&builder, inner_result_elements)); } - builder.Tuple(result_elements); + Tuple(&builder, result_elements); auto computation = builder.Build().ConsumeValueOrDie(); // Construct the argument to pass to the computation. @@ -520,14 +512,14 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { } XlaBuilder builder(TestName()); - auto element = builder.Parameter(0, shape, "param"); + auto element = Parameter(&builder, 0, shape, "param"); for (int i = 0; i < kTupleDepth; ++i) { - element = builder.GetTupleElement(element, 0); + element = GetTupleElement(element, 0); } - auto output = builder.Add(element, builder.ConstantR0(42.0)); + auto output = Add(element, ConstantR0(&builder, 42.0)); for (int i = 0; i < kTupleDepth; ++i) { - output = builder.Tuple({output}); + output = Tuple(&builder, {output}); } auto computation = builder.Build().ConsumeValueOrDie(); @@ -547,16 +539,16 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { for (int i = 0; i < kTupleDepth; ++i) { index.push_back(0); } - LiteralTestUtil::ExpectR0Equal( - 165.0, LiteralSlice(*result_literal, index)); + LiteralTestUtil::ExpectR0Equal(165.0, + LiteralSlice(*result_literal, index)); } XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { // Test passing in an invalid number of arguments. XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {3}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {3}), "y"); + Add(x, y); auto x_array = LiteralToShapedBuffer(*Literal::CreateR1({1.0f, 2.0f, 3.0f})); @@ -571,8 +563,8 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { // Test passing in an argument with the wrong shape. XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); - builder.Neg(x); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); + Neg(x); auto x_array = LiteralToShapedBuffer( *Literal::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); @@ -588,8 +580,8 @@ XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) { // Test passing in an invalid result layout parameter. XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); - builder.Neg(x); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + Neg(x); auto x_array = LiteralToShapedBuffer( *Literal::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); @@ -611,7 +603,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) { // Try to run a trivial computation on every device on the system. If a // specific device is not supported, check that the right error is returned. XlaBuilder builder(TestName()); - builder.ConstantR0(42.0f); + ConstantR0(&builder, 42.0f); auto computation = builder.Build().ConsumeValueOrDie(); for (int d = 0; d < local_client_->device_count(); ++d) { if (!local_client_->device_ordinal_supported(d)) { @@ -638,7 +630,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) { // Try running computations on devices with device ordinal values which do not // exist. XlaBuilder builder(TestName()); - builder.ConstantR0(42.0f); + ConstantR0(&builder, 42.0f); auto computation = builder.Build().ConsumeValueOrDie(); auto execute_status = @@ -655,7 +647,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) { XLA_TEST_F(LocalClientExecuteTest, RunOnStream) { // Run a computation on a specific stream on each device on the system. XlaBuilder builder(TestName()); - builder.ConstantR0(42.0f); + ConstantR0(&builder, 42.0f); auto computation = builder.Build().ConsumeValueOrDie(); for (int d = 0; d < local_client_->device_count(); ++d) { @@ -691,7 +683,7 @@ XLA_TEST_F(LocalClientExecuteTest, wrong_stream.Init(); XlaBuilder builder(TestName()); - builder.ConstantR0(42.0f); + ConstantR0(&builder, 42.0f); auto execute_status = ExecuteLocally( builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions().set_stream(&wrong_stream)); @@ -708,7 +700,7 @@ XLA_TEST_F(LocalClientExecuteTest, TestAllocator allocator(wrong_platform); XlaBuilder builder(TestName()); - auto y = builder.ConstantR0(123.0f); + ConstantR0(&builder, 123.0f); auto execute_status = ExecuteLocally( builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(), @@ -721,7 +713,7 @@ XLA_TEST_F(LocalClientExecuteTest, XLA_TEST_F(LocalClientExecuteTest, RunOnUninitializedStream) { // Try to run a computation on a stream that has not been initialized. XlaBuilder builder(TestName()); - builder.ConstantR0(42.0f); + ConstantR0(&builder, 42.0f); LOG(INFO) << "default device = " << local_client_->default_device_ordinal(); se::StreamExecutor* executor = @@ -744,26 +736,26 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) { std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; - auto tuple12 = builder.Tuple( - {builder.ConstantR1(vec1), builder.ConstantR1(vec2)}); - auto tuple21 = builder.Tuple( - {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); - builder.Select(builder.ConstantR0(false), tuple12, tuple21); + auto tuple12 = Tuple(&builder, {ConstantR1(&builder, vec1), + ConstantR1(&builder, vec2)}); + auto tuple21 = Tuple(&builder, {ConstantR1(&builder, vec2), + ConstantR1(&builder, vec1)}); + Select(ConstantR0(&builder, false), tuple12, tuple21); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); std::unique_ptr tuple_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR1Equal( - {2.0f, 4.0f, 6.0f}, LiteralSlice(*tuple_literal, {0})); - LiteralTestUtil::ExpectR1Equal( - {1.0f, 2.0f, 3.0f}, LiteralSlice(*tuple_literal, {1})); + LiteralTestUtil::ExpectR1Equal({2.0f, 4.0f, 6.0f}, + LiteralSlice(*tuple_literal, {0})); + LiteralTestUtil::ExpectR1Equal({1.0f, 2.0f, 3.0f}, + LiteralSlice(*tuple_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); - auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); + auto y = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); + Add(x, y); Shape argument_layout = ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0}); @@ -779,6 +771,10 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { ScopedShapedBuffer result = executable->Run({&x_array}, DefaultExecutableRunOptions()) .ConsumeValueOrDie(); + ASSERT_IS_OK(local_client_->mutable_backend() + ->BorrowStream(0) + .ValueOrDie() + ->BlockHostUntilDone()); LiteralTestUtil::ExpectR1Near( {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); @@ -848,15 +844,40 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { Literal::CreateR0(123456789000LL).get()})); } +XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { + XlaBuilder builder(TestName()); + const Shape shape = ShapeUtil::MakeShape(F32, {3}); + auto in = Infeed(&builder, shape); + auto constant = ConstantR1(&builder, {1.0f, 2.0f, 3.0f}); + Add(in, constant); + + std::unique_ptr result; + std::unique_ptr thread( + tensorflow::Env::Default()->StartThread( + tensorflow::ThreadOptions(), "execute_thread", [&] { + result = ShapedBufferToLiteral(ExecuteLocallyOrDie( + builder.Build().ValueOrDie(), /*arguments=*/{})); + })); + + ASSERT_IS_OK(local_client_->TransferToInfeedLocal( + *Literal::CreateR1({-5.0, 123.0, 42.0}), + local_client_->default_device_ordinal())); + + // Join the thread. + thread.reset(); + + LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, *result); +} + // TODO(b/34359662): Support infeed/outfeed on GPU and CPU parallel. // 2017-10-18. XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_GPU(InfeedOutfeedTest)) { XlaBuilder builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {3}); - auto in = builder.Infeed(shape); - auto constant = builder.ConstantR1({1.0f, 2.0f, 3.0f}); - auto sum = builder.Add(in, constant); - builder.Outfeed(sum, shape, /*outfeed_config=*/""); + auto in = Infeed(&builder, shape); + auto constant = ConstantR1(&builder, {1.0f, 2.0f, 3.0f}); + auto sum = Add(in, constant); + Outfeed(sum, shape, /*outfeed_config=*/""); std::unique_ptr thread( tensorflow::Env::Default()->StartThread( @@ -891,8 +912,8 @@ void BM_LocalClientOverhead(int num_iters) { // Use a tiny add operation as the computation. XlaBuilder builder("Add"); auto shape = ShapeUtil::MakeShape(F32, {2, 3}); - auto x = builder.Parameter(0, shape, "x"); - builder.Add(x, x); + auto x = Parameter(&builder, 0, shape, "x"); + Add(x, x); auto computation = builder.Build().ConsumeValueOrDie(); auto buffer = @@ -900,8 +921,10 @@ void BM_LocalClientOverhead(int num_iters) { ->AllocateScopedShapedBuffer(shape, &allocator, /*device_ordinal=*/0) .ConsumeValueOrDie(); auto literal = Literal::CreateR2({{0, 0, 0}, {0, 0, 0}}); - ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *literal, buffer)); + auto stream = + client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); + ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(stream.get(), *literal, + buffer)); const int kWarmups = 2; @@ -911,11 +934,8 @@ void BM_LocalClientOverhead(int num_iters) { std::unique_ptr executable = executable_status.ConsumeValueOrDie(); - se::Stream stream(executors[client->default_device_ordinal()]); - stream.Init(); - ExecutableRunOptions run_options; - run_options.set_allocator(&allocator).set_stream(&stream); + run_options.set_allocator(&allocator).set_stream(stream.get()); for (int i = 0; i < kWarmups; ++i) { auto result = executable->Run({&buffer}, run_options); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index 88797a7d0a7d0567b3a380c5fb1ad0c0ee875587..c31ba0e713a45d18b60bfdb9a47545cf34220333 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -189,7 +189,19 @@ StatusOr LocalClientTestBase::ExecuteLocally( TF_ASSIGN_OR_RETURN( std::unique_ptr executable, local_client_->Compile(computation, argument_layouts, build_options)); - return executable->Run(arguments, run_options); + TF_ASSIGN_OR_RETURN(auto ret, executable->Run(arguments, run_options)); + + auto device_ordinal = + build_options.device_ordinal() == -1 ? 0 : build_options.device_ordinal(); + auto* stream = run_options.stream(); + if (!stream) { + stream = local_client_->mutable_backend() + ->BorrowStream(device_ordinal) + .ValueOrDie() + .get(); + } + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + return std::move(ret); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/log_test.cc b/tensorflow/compiler/xla/tests/log_test.cc index c0c02e584c2348f64a9d7d0800038f5ca67a2171..cdf70ee4185be2ecd9dcb2d21fbd98c2ab6cc0ad 100644 --- a/tensorflow/compiler/xla/tests/log_test.cc +++ b/tensorflow/compiler/xla/tests/log_test.cc @@ -30,8 +30,8 @@ class LogTest : public ClientLibraryTestBase {}; XLA_TEST_F(LogTest, LogZeroValues) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR3FromArray3D(Array3D(3, 0, 0)); - builder.Log(x); + auto x = ConstantR3FromArray3D(&builder, Array3D(3, 0, 0)); + Log(x); ComputeAndCompareR3(&builder, Array3D(3, 0, 0), {}, ErrorSpec(0.0001)); @@ -42,8 +42,8 @@ TEST_F(LogTest, LogTenValues) { 5.0, 6.0, -7.0, -8.0, 9.0}; XlaBuilder builder(TestName()); - auto x = builder.ConstantR1(input); - builder.Log(x); + auto x = ConstantR1(&builder, input); + Log(x); std::vector expected; expected.reserve(input.size()); diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 3975e9125703ee081d4e84fa8bd27fcbe483ac34..1b3bc9d5040e1382f534e00ea2679ebbd48ceb59 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -52,9 +52,9 @@ class MapTest : public ClientLibraryTestBase { // 1.0f ---------/ XlaComputation CreateAdderToOne() { XlaBuilder mapped_builder(TestName()); - auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto one = mapped_builder.ConstantR0(1.0); - mapped_builder.Add(x, one); + auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto one = ConstantR0(&mapped_builder, 1.0); + Add(x, one); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -62,9 +62,9 @@ class MapTest : public ClientLibraryTestBase { XlaComputation CreateMax() { XlaBuilder b(TestName()); - auto lhs = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto rhs = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - b.Max(lhs, rhs); + auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Max(lhs, rhs); auto computation_status = b.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -75,8 +75,8 @@ class MapTest : public ClientLibraryTestBase { template XlaComputation CreateScalarOne() { XlaBuilder mapped_builder("scalar_one"); - (void)mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - mapped_builder.ConstantR0(1); + (void)Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + ConstantR0(&mapped_builder, 1); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -89,9 +89,9 @@ class MapTest : public ClientLibraryTestBase { // 2.0f ---------/ XlaComputation CreateMulByTwo() { XlaBuilder mapped_builder(TestName()); - auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto two = mapped_builder.ConstantR0(2.0); - mapped_builder.Mul(x, two); + auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto two = ConstantR0(&mapped_builder, 2.0); + Mul(x, two); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -107,10 +107,10 @@ class MapTest : public ClientLibraryTestBase { // 1.0f ---------/ XlaComputation CreateAdderToOneTimesItself() { XlaBuilder mapped_builder(TestName()); - auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto one = mapped_builder.ConstantR0(1.0); - auto adder_to_one = mapped_builder.Add(x, one); - mapped_builder.Mul(x, adder_to_one); + auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto one = ConstantR0(&mapped_builder, 1.0); + auto adder_to_one = Add(x, one); + Mul(x, adder_to_one); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -125,10 +125,10 @@ class MapTest : public ClientLibraryTestBase { XlaComputation CreateMapPlusN(const XlaComputation& embedded_computation, float n) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto map = builder.Map({x}, embedded_computation, {}); - auto constant_n = builder.ConstantR0(n); - builder.Add(map, constant_n); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto map = Map(&builder, {x}, embedded_computation, {}); + auto constant_n = ConstantR0(&builder, n); + Add(map, constant_n); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -138,9 +138,9 @@ class MapTest : public ClientLibraryTestBase { // defined by (x, y) -> x > y. XlaComputation CreateGt() { XlaBuilder b("Gt"); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - b.Gt(x, y); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Gt(x, y); auto computation_status = b.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -155,11 +155,11 @@ class MapTest : public ClientLibraryTestBase { // z {R0F32} ---------------/ XlaComputation CreateTernaryAdder() { XlaBuilder mapped_builder("TernaryAdder"); - auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = mapped_builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - auto z = mapped_builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "z"); - auto xy = mapped_builder.Add(x, y); - mapped_builder.Add(xy, z); + auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&mapped_builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + auto z = Parameter(&mapped_builder, 2, ShapeUtil::MakeShape(F32, {}), "z"); + auto xy = Add(x, y); + Add(xy, z); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -173,8 +173,8 @@ TEST_F(MapTest, MapEachElemPlusOneR0) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param}, CreateAdderToOne(), {}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param}, CreateAdderToOne(), {}); ComputeAndCompareR0(&builder, 43.0, {param0_data.get()}, ErrorSpec(0.01f)); @@ -187,8 +187,8 @@ XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param}, CreateAdderToOne(), {0}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -202,8 +202,8 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param}, CreateAdderToOne(), {0}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {3.2f, 4.3f, 5.4f, 6.5f}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -216,8 +216,8 @@ TEST_F(MapTest, MapEachF32ElementToS32Constant) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param}, CreateScalarOne(), {0}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param}, CreateScalarOne(), {0}); ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); } @@ -229,8 +229,8 @@ TEST_F(MapTest, MapEachF32ElementToU32Constant) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param}, CreateScalarOne(), {0}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param}, CreateScalarOne(), {0}); ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); } @@ -243,8 +243,8 @@ TEST_F(MapTest, MapEachElemLongerChainR1) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param}, CreateAdderToOneTimesItself(), {0}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param}, CreateAdderToOneTimesItself(), {0}); ComputeAndCompareR1( &builder, {9.36f, 20.91f, 0.11f, 0.24f, 999000.0f, 65535.75f}, @@ -259,9 +259,9 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map1 = builder.Map({param}, CreateAdderToOne(), {0}); - builder.Map({map1}, CreateMulByTwo(), {0}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0}); + Map(&builder, {map1}, CreateMulByTwo(), {0}); ComputeAndCompareR1(&builder, {}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -276,9 +276,9 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map1 = builder.Map({param}, CreateAdderToOne(), {0}); - builder.Map({map1}, CreateMulByTwo(), {0}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0}); + Map(&builder, {map1}, CreateMulByTwo(), {0}); ComputeAndCompareR1(&builder, {6.4f, 8.6f, 10.8f, 13.0f}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -292,8 +292,8 @@ TEST_F(MapTest, MapEachElemPlusOneR2) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param}, CreateAdderToOne(), {0, 1}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param}, CreateAdderToOne(), {0, 1}); Array2D expected_array( {{14.25f, 15.0f}, {-6.1f, -6.2f}, {-7.8f, 9.8f}}); @@ -319,10 +319,10 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) { auto embed3 = CreateMapPlusN(embed1, 4.0); XlaBuilder embed4_builder("embed4"); - auto embed4_param = embed4_builder.Parameter(0, scalar_shape, "x"); - auto embed4_map_lhs = embed4_builder.Map({embed4_param}, embed2, {}); - auto embed4_map_rhs = embed4_builder.Map({embed4_param}, embed3, {}); - embed4_builder.Add(embed4_map_lhs, embed4_map_rhs); + auto embed4_param = Parameter(&embed4_builder, 0, scalar_shape, "x"); + auto embed4_map_lhs = Map(&embed4_builder, {embed4_param}, embed2, {}); + auto embed4_map_rhs = Map(&embed4_builder, {embed4_param}, embed3, {}); + Add(embed4_map_lhs, embed4_map_rhs); auto embed4_status = embed4_builder.Build(); ASSERT_IS_OK(embed4_status.status()); auto embed4 = embed4_status.ConsumeValueOrDie(); @@ -330,11 +330,11 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) { auto embed5 = CreateMapPlusN(embed2, 6.0); XlaBuilder builder(TestName()); - auto constant_42 = builder.ConstantR0(42.0); - auto constant_7 = builder.ConstantR0(7.0); - auto map_42 = builder.Map({constant_42}, embed5, {}); - auto map_7 = builder.Map({constant_7}, embed4, {}); - builder.Add(map_42, map_7); + auto constant_42 = ConstantR0(&builder, 42.0); + auto constant_7 = ConstantR0(&builder, 7.0); + auto map_42 = Map(&builder, {constant_42}, embed5, {}); + auto map_7 = Map(&builder, {constant_7}, embed4, {}); + Add(map_42, map_7); ComputeAndCompareR0(&builder, 73.0, {}, ErrorSpec(0.01f)); } @@ -351,9 +351,10 @@ TEST_F(MapTest, MapBinaryAdder) { std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Map({param0, param1}, CreateScalarAddComputation(F32, &builder), {0}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Map(&builder, {param0, param1}, CreateScalarAddComputation(F32, &builder), + {0}); ComputeAndCompareR1(&builder, {7.3f, 7.7, 4.3f, 0}, {param0_data.get(), param1_data.get()}, @@ -374,10 +375,10 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) { std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder), - {0, 1}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder), + {0, 1}); Array2D expected(2, 2); expected(0, 0) = 11; @@ -400,10 +401,10 @@ XLA_TEST_F(MapTest, AddR3_3x0x2) { std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder), - {0, 1, 2}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder), + {0, 1, 2}); ComputeAndCompareR3(&builder, Array3D(3, 0, 2), {param0_data.get(), param1_data.get()}); @@ -425,10 +426,10 @@ TEST_F(MapTest, MapTernaryAdder) { std::unique_ptr param2_data = client_->TransferToServer(*param2_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto param2 = builder.Parameter(2, param2_literal->shape(), "param2"); - builder.Map({param0, param1, param2}, CreateTernaryAdder(), {0}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param2 = Parameter(&builder, 2, param2_literal->shape(), "param2"); + Map(&builder, {param0, param1, param2}, CreateTernaryAdder(), {0}); ComputeAndCompareR1( &builder, {-2.7f, -92.3f, -895.7f, -400.0f}, @@ -440,7 +441,8 @@ TEST_F(MapTest, MapGt) { // Maps (x,y) -> x > y onto two R1F32 vectors. XlaBuilder b(TestName()); auto gt = CreateGt(); - b.Map({b.ConstantR1({1, 20}), b.ConstantR1({10, 2})}, gt, {0}); + Map(&b, {ConstantR1(&b, {1, 20}), ConstantR1(&b, {10, 2})}, gt, + {0}); ComputeAndCompareR1(&b, {false, true}, {}); } @@ -449,15 +451,15 @@ TEST_F(MapTest, NestedBinaryMap) { { // max_with_square(x) = do max(x, x^2) via a map. XlaBuilder b("max_with_square"); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - b.Map({x, b.Mul(x, x)}, CreateMax(), {}); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x"); + Map(&b, {x, Mul(x, x)}, CreateMax(), {}); auto computation_status = b.Build(); ASSERT_IS_OK(computation_status.status()); max_with_square = computation_status.ConsumeValueOrDie(); } XlaBuilder b(TestName()); - auto input = b.ConstantR1({0.1f, 0.5f, -0.5f, 1.0f, 2.0f}); - b.Map({input}, max_with_square, {0}); + auto input = ConstantR1(&b, {0.1f, 0.5f, -0.5f, 1.0f, 2.0f}); + Map(&b, {input}, max_with_square, {0}); ComputeAndCompareR1(&b, {0.1f, 0.5f, 0.25f, 1.0f, 4.0f}, {}); } @@ -468,9 +470,9 @@ TEST_F(MapTest, MapOperantionWithBuildError) { XlaBuilder builder(TestName()); auto sub_builder = builder.CreateSubBuilder("ErrorAdd"); - auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(U16, {}), "y"); - sub_builder->Add(x, y); + auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(sub_builder.get(), 1, ShapeUtil::MakeShape(U16, {}), "y"); + Add(x, y); auto error_add = sub_builder->BuildAndNoteError(); std::unique_ptr param0_literal = @@ -482,9 +484,9 @@ TEST_F(MapTest, MapOperantionWithBuildError) { std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Map({param0, param1}, error_add, {0}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Map(&builder, {param0, param1}, error_add, {0}); StatusOr computation_status = builder.Build(); ASSERT_TRUE(!computation_status.ok()); @@ -506,9 +508,9 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) { XlaBuilder builder(TestName()); auto sub_builder = builder.CreateSubBuilder("power"); - auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - sub_builder->Pow(x, y); + auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(sub_builder.get(), 1, ShapeUtil::MakeShape(F32, {}), "y"); + Pow(x, y); auto power = sub_builder->BuildAndNoteError(); std::unique_ptr param0_literal = Literal::CreateR0(2.0f); @@ -518,9 +520,9 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) { std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Map({param0, param1}, power, {}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Map(&builder, {param0, param1}, power, {}); ComputeAndCompareR0(&builder, 32.0f, {param0_data.get(), param1_data.get()}, @@ -533,9 +535,9 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { XlaBuilder builder(TestName()); auto sub_builder = builder.CreateSubBuilder("power"); - auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - sub_builder->Sub(y, x); // note that this is y - x, not x - y + auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(sub_builder.get(), 1, ShapeUtil::MakeShape(F32, {}), "y"); + Sub(y, x); // note that this is y - x, not x - y auto sub_opposite = sub_builder->BuildAndNoteError(); std::unique_ptr param0_literal = Literal::CreateR0(2.0f); @@ -545,9 +547,9 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Map({param0, param1}, sub_opposite, {}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Map(&builder, {param0, param1}, sub_opposite, {}); ComputeAndCompareR0( &builder, 3.0f, {param0_data.get(), param1_data.get()}, ErrorSpec(0.01f)); @@ -559,16 +561,16 @@ TEST_F(MapTestWithFullOpt, MapSquare) { XlaBuilder builder(TestName()); auto sub_builder = builder.CreateSubBuilder("power"); - auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - sub_builder->Mul(x, x); + auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x"); + Mul(x, x); auto square = sub_builder->BuildAndNoteError(); std::unique_ptr param0_literal = Literal::CreateR0(10.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param0}, square, {}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param0}, square, {}); ComputeAndCompareR0(&builder, 100.0f, {param0_data.get()}, ErrorSpec(0.01f)); diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 27fd36e06acdc589f3a84ad561164e4a33b93506..17b1807f44a457786906afc15d8d410f6cf2d4cd 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -56,11 +56,11 @@ TYPED_TEST_CASE(MatOpsSimpleTest_F16F32, TypesF16F32); XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) { using T = TypeParam; XlaBuilder builder("exp_2x2"); - auto data = builder.ConstantR2FromArray2D({ - {1.0f, 0.0f}, // row 0 - {-1.0f, 0.5f}, // row 1 - }); - builder.Exp(data); + auto data = ConstantR2FromArray2D(&builder, { + {1.0f, 0.0f}, // row 0 + {-1.0f, 0.5f}, // row 1 + }); + Exp(data); std::unique_ptr expected = Literal::CreateR2FromArray2D({{2.71828f, 1.00000f}, // row 0 @@ -76,20 +76,20 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { // add_half(x) = x + 0.5 XlaBuilder builder("add_half"); auto x_value = - builder.Parameter(0, ShapeUtil::MakeShapeWithType({}), "x_value"); - auto half = builder.ConstantR0(static_cast(0.5)); - builder.Add(x_value, half); + Parameter(&builder, 0, ShapeUtil::MakeShapeWithType({}), "x_value"); + auto half = ConstantR0(&builder, static_cast(0.5)); + Add(x_value, half); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); add_half = computation_status.ConsumeValueOrDie(); } XlaBuilder builder("map_2x2"); - auto data = builder.ConstantR2FromArray2D({ - {1.0f, 0.0f}, // row 0 - {-1.0f, 0.5f}, // row 1 - }); - auto map = builder.Map({data}, add_half, {0, 1}); + auto data = ConstantR2FromArray2D(&builder, { + {1.0f, 0.0f}, // row 0 + {-1.0f, 0.5f}, // row 1 + }); + Map(&builder, {data}, add_half, {0, 1}); std::unique_ptr expected = Literal::CreateR2FromArray2D({{1.5f, 0.5f}, // row 0 @@ -100,15 +100,15 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { using T = TypeParam; XlaBuilder builder("max_2x2"); - auto lhs = builder.ConstantR2FromArray2D({ - {7.0f, 2.0f}, // row 0 - {3.0f, -4.0f}, // row 1 - }); - auto rhs = builder.ConstantR2FromArray2D({ - {5.0f, 6.0f}, // row 0 - {1.0f, -8.0f}, // row 1 - }); - auto max = builder.Max(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, { + {7.0f, 2.0f}, // row 0 + {3.0f, -4.0f}, // row 1 + }); + auto rhs = ConstantR2FromArray2D(&builder, { + {5.0f, 6.0f}, // row 0 + {1.0f, -8.0f}, // row 1 + }); + Max(lhs, rhs); std::unique_ptr expected = Literal::CreateR2FromArray2D({{7.0f, 6.0f}, // row 0 @@ -137,9 +137,9 @@ class TestLinspaceMaxParametric XlaBuilder builder( tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols)); - auto lhs = builder.ConstantR2FromArray2D(*alhs); - auto rhs = builder.ConstantR2FromArray2D(*arhs); - auto max = builder.Max(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, *alhs); + auto rhs = ConstantR2FromArray2D(&builder, *arhs); + Max(lhs, rhs); Array2D expected(rows, cols); for (int row = 0; row < rows; ++row) { @@ -208,23 +208,23 @@ class MatOpsDotAddTest rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); XlaBuilder builder(TestName()); - auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs"); + auto lhs_arg = Parameter(&builder, 0, lhs_shape, "lhs"); auto lhs_mat_arg = lhs_arg; if (transpose) { - lhs_mat_arg = builder.Transpose(lhs_mat_arg, {1, 0}); + lhs_mat_arg = Transpose(lhs_mat_arg, {1, 0}); } - auto rhs_arg = builder.Parameter(1, rhs_shape, "rhs"); - auto result = builder.Dot(lhs_mat_arg, rhs_arg); + auto rhs_arg = Parameter(&builder, 1, rhs_shape, "rhs"); + auto result = Dot(lhs_mat_arg, rhs_arg); Array2D expected; if (add_lhs) { - result = builder.Add(result, lhs_arg); + result = Add(result, lhs_arg); if (transpose) { expected = Array2D({{47.0f, 52.0f}, {71.0f, 78.0f}}); } else { expected = Array2D({{35.0f, 39.0f}, {81.0f, 89.0f}}); } } else { - result = builder.Add(result, rhs_arg); + result = Add(result, rhs_arg); if (transpose) { expected = Array2D({{56.0f, 61.0f}, {80.0f, 87.0f}}); } else { diff --git a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc index 0791a71aacf7614286fe964623a3172a174d4722..e576f000ef23e761d6fa818457eec2144d4bcb00 100644 --- a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc +++ b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc @@ -33,9 +33,10 @@ class SliceTest : public ClientLibraryTestBase {}; XLA_TEST_F(SliceTest, Slice2D) { XlaBuilder builder("slice_2d"); - auto original = builder.ConstantR2( + auto original = ConstantR2( + &builder, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}, {10.0, 11.0, 12.0}}); - builder.Slice(original, {2, 1}, {4, 3}, {1, 1}); + Slice(original, {2, 1}, {4, 3}, {1, 1}); Array2D expected({{8.0f, 9.0f}, {11.0f, 12.0f}}); ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.000001)); @@ -45,8 +46,8 @@ XLA_TEST_F(SliceTest, Slice3D) { XlaBuilder builder("slice_3d"); Array3D array_3d( {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}); - auto original = builder.ConstantR3FromArray3D(array_3d); - builder.Slice(original, {0, 0, 1}, {2, 1, 2}, {1, 1, 1}); + auto original = ConstantR3FromArray3D(&builder, array_3d); + Slice(original, {0, 0, 1}, {2, 1, 2}, {1, 1, 1}); Array3D expected_3d({{{2.0f}}, {{6.0f}}}); ComputeAndCompareR3(&builder, expected_3d, {}, ErrorSpec(0.000001)); diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 41f723edf1ff3518686231f31b61b64291b1f6bf..6597748c8d1f45391799dbe384a5afc0284de2dd 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -204,10 +204,10 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { Literal::CreateR0(1.0)), Literal::MakeTupleOwned(Literal::CreateR0(3.0), Literal::CreateR0(4))); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::MakeTupleOwned(Literal::CreateR0(42)))); + *Literal::MakeTupleOwned(Literal::CreateR0(42)), *result)); } XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { @@ -233,10 +233,9 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = Literal::CreateR1({1.0, 2.0, 3.0, -1.0}); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); - EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::CreateR1({0.0, 4.0, 9.0, 1.0}))); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); + LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0, 1.0}, *result); } XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { @@ -267,10 +266,9 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = Literal::CreateR1({1.0, 2.0, 3.0}); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); - EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::CreateR1({0.0, 4.0, 9.0}))); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); + LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0}, *result); } const char* const kScalarOps = R"( @@ -311,12 +309,12 @@ XLA_TEST_F(MultiOutputFusionTest, HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::MakeTupleOwned(Literal::CreateR2({{3, 7}, {11, 15}}), - Literal::CreateR2({{5, 16}, {36, 64}})))); + Literal::CreateR2({{5, 16}, {36, 64}})), + *result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -341,12 +339,12 @@ XLA_TEST_F(MultiOutputFusionTest, HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::MakeTupleOwned( - Literal::CreateR2({{6, 8}, {10, 12}}), - Literal::CreateR2({{25, 36}, {49, 64}})))); + *Literal::MakeTupleOwned(Literal::CreateR2({{6, 8}, {10, 12}}), + Literal::CreateR2({{25, 36}, {49, 64}})), + *result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -372,12 +370,13 @@ XLA_TEST_F(MultiOutputFusionTest, HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::MakeTupleOwned(Literal::CreateR1({14, 22}), - Literal::CreateR1({36, 64}), - Literal::CreateR1({66, 138})))); + *Literal::MakeTupleOwned(Literal::CreateR1({14, 22}), + Literal::CreateR1({36, 64}), + Literal::CreateR1({66, 138})), + *result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -403,14 +402,14 @@ XLA_TEST_F(MultiOutputFusionTest, HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::MakeTupleOwned( Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}), Literal::CreateR2({{3, 7}, {11, 15}}), - Literal::CreateR2({{5, 16}, {36, 64}})))); + Literal::CreateR2({{5, 16}, {36, 64}})), + *result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -436,14 +435,14 @@ XLA_TEST_F(MultiOutputFusionTest, HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::MakeTupleOwned( Literal::CreateR2({{6, 8}, {10, 12}}), Literal::CreateR3({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), - Literal::CreateR2({{25, 36}, {49, 64}})))); + Literal::CreateR2({{25, 36}, {49, 64}})), + *result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -455,7 +454,8 @@ XLA_TEST_F(MultiOutputFusionTest, r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add mul = f32[2,2,2]{2,1,0} multiply(p0, p0) c1 = f32[] constant(5) - mul2 = f32[2,2,2]{2,1,0} multiply(p0, c1) + b1 = f32[2,2,2]{2,1,0} broadcast(c1), dimensions={} + mul2 = f32[2,2,2]{2,1,0} multiply(p0, b1) ROOT tuple = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) tuple(r1, mul, mul2) } @@ -469,15 +469,15 @@ XLA_TEST_F(MultiOutputFusionTest, HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::MakeTupleOwned( Literal::CreateR1({14, 22}), Literal::CreateR3({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), Literal::CreateR3( - {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})))); + {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})), + *result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -505,13 +505,52 @@ XLA_TEST_F(MultiOutputFusionTest, auto param = Literal::CreateR3({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); auto init1 = Literal::CreateR0(5); auto init2 = Literal::CreateR0(6); - TF_ASSERT_OK_AND_ASSIGN( - auto result, - Execute(std::move(module), {param.get(), init1.get(), init2.get()})); + std::unique_ptr result = ExecuteNoHloPasses( + std::move(module), {param.get(), init1.get(), init2.get()}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::MakeTupleOwned( + Literal::CreateR2({{167, 172}, {176, 180}}), + Literal::CreateR2({{6, 6}, {6, 8}})), + *result)); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionDifferentElementTypes)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce (p0: f16[2,2,2]) -> (f32[2,2], f32[2,2], f16[2,2,2]) { + p0 = f16[2,2,2]{2,1,0} parameter(0) + convert = f32[2,2,2]{2,1,0} convert(p0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(convert, c0), dimensions={2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(convert, convert) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) + tuple(r1, r2, p0) + } + + ENTRY reduce { + p = f16[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) fusion(p), + kind=kInput, calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3( + {{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}}, + {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}}); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::MakeTupleOwned( - Literal::CreateR2({{167, 172}, {176, 180}}), - Literal::CreateR2({{6, 6}, {6, 8}})))); + *Literal::MakeTupleOwned( + Literal::CreateR2({{3, 7}, {11, 15}}), + Literal::CreateR2({{5, 16}, {36, 64}}), + Literal::CreateR3({{{Eigen::half(1), Eigen::half(2)}, + {Eigen::half(3), Eigen::half(4)}}, + {{Eigen::half(5), Eigen::half(6)}, + {Eigen::half(7), Eigen::half(8)}}})), + *result)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc index ce295b832d79e4f00656f2893c2ba1162693dd73..2e5081bbcb64ea9416c5a9731dba43891ecceedf 100644 --- a/tensorflow/compiler/xla/tests/pad_test.cc +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -93,8 +93,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(0); - b.Pad(AddParam(*Literal::CreateR1({}), &b), - AddParam(*Literal::CreateR0(0.1), &b), padding_config); + Pad(AddParam(*Literal::CreateR1({}), &b), + AddParam(*Literal::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, {}, {}, DefaultErrorSpec()); } @@ -108,8 +108,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) { dimension->set_edge_padding_high(4); dimension->set_interior_padding(7); - b.Pad(AddParam(*Literal::CreateR1({}), &b), - AddParam(*Literal::CreateR0(0.1), &b), padding_config); + Pad(AddParam(*Literal::CreateR1({}), &b), + AddParam(*Literal::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, std::vector(5, 0.1), {}, DefaultErrorSpec()); } @@ -123,16 +123,16 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(1); - b.Pad(AddParam(*Literal::CreateR1({1, 2, 3}), &b), - AddParam(*Literal::CreateR0(0.1), &b), padding_config); + Pad(AddParam(*Literal::CreateR1({1, 2, 3}), &b), + AddParam(*Literal::CreateR0(0.1), &b), padding_config); std::vector expected({0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3}); ComputeAndCompareR1(&b, expected, {}, DefaultErrorSpec()); } XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) { XlaBuilder b(TestName()); - b.Pad(AddParam(Array4D(2, 0, 3, 2), &b), - AddParam(*Literal::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); + Pad(AddParam(Array4D(2, 0, 3, 2), &b), + AddParam(*Literal::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); ComputeAndCompareR4(&b, Array4D(5, 2, 3, 2, 1.5f), {}, DefaultErrorSpec()); } @@ -147,8 +147,8 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) { }); input->FillWithYX(input_xy); - b.Pad(AddParam(*input, &b), AddParam(*Literal::CreateR0(1.5), &b), - r4_padding_on_dim0_dim1_); + Pad(AddParam(*input, &b), AddParam(*Literal::CreateR0(1.5), &b), + r4_padding_on_dim0_dim1_); auto expected = MakeUnique>(2, 3, 3, 2); expected->Fill(1.5); @@ -166,8 +166,8 @@ TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) { const float pad_value = 1.5f; Array4D input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); - b.Pad(AddParam(input, &b), AddParam(*Literal::CreateR0(pad_value), &b), - r4_padding_on_dim0_dim1_); + Pad(AddParam(input, &b), AddParam(*Literal::CreateR0(pad_value), &b), + r4_padding_on_dim0_dim1_); auto expected = MakeUnique>(8, 5, 1, 1); expected->Fill(pad_value); @@ -208,8 +208,8 @@ TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) { auto input = Literal::CreateR4FromArray4D(input_array); input = input->Relayout(layout); - b.Pad(AddParam(*input, &b), - AddParam(*Literal::CreateR0(pad_value), &b), padding_config); + Pad(AddParam(*input, &b), AddParam(*Literal::CreateR0(pad_value), &b), + padding_config); Array4D expected_array(1, 1, 5, 8); expected_array.Fill(pad_value); @@ -254,8 +254,8 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { auto input = Literal::CreateR4FromArray4D(input_array); input = input->Relayout(layout); - b.Pad(AddParam(*input, &b), - AddParam(*Literal::CreateR0(pad_value), &b), padding_config); + Pad(AddParam(*input, &b), AddParam(*Literal::CreateR0(pad_value), &b), + padding_config); Array4D expected_array(1, 25, 17, 11); expected_array.Fill(pad_value); @@ -275,8 +275,8 @@ XLA_TEST_F(PadTest, Pad4DU8Array) { }); input->FillWithYX(input_xy); - b.Pad(AddParam(*input, &b), b.ConstantR0(35), - r4_padding_on_dim0_dim1_); + Pad(AddParam(*input, &b), ConstantR0(&b, 35), + r4_padding_on_dim0_dim1_); auto expected = MakeUnique>(2, 3, 3, 2); expected->Fill(35); @@ -294,16 +294,16 @@ XLA_TEST_F(PadTest, Pad4DPredArray) { // Since bool is currently not well supported, use Broadcast operation to // create the operand for Pad. - auto input = b.Broadcast(b.ConstantR0(true), {1, 1, 3, 2}); + auto input = Broadcast(ConstantR0(&b, true), {1, 1, 3, 2}); auto padded = - b.Pad(input, b.ConstantR0(false), r4_padding_on_dim0_dim1_); + Pad(input, ConstantR0(&b, false), r4_padding_on_dim0_dim1_); // For the same reason, use Select to convert boolean values to int32. auto zeros = MakeUnique>(2, 3, 3, 2); auto ones = MakeUnique>(2, 3, 3, 2); zeros->Fill(0); ones->Fill(1); - b.Select(padded, AddParam(*ones, &b), AddParam(*zeros, &b)); + Select(padded, AddParam(*ones, &b), AddParam(*zeros, &b)); auto expected = MakeUnique>(2, 3, 3, 2); expected->Fill(0); @@ -329,7 +329,7 @@ XLA_TEST_P(PadTestFloat, Large2DPad) { padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 + 100 * dim); } - b.Pad(input, AddParam(*Literal::CreateR0(0.0f), &b), padding_config); + Pad(input, AddParam(*Literal::CreateR0(0.0f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); @@ -351,7 +351,7 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) { padding_config.mutable_dimensions(1)->set_edge_padding_low(6); padding_config.mutable_dimensions(1)->set_edge_padding_high(4); padding_config.mutable_dimensions(1)->set_interior_padding(2); - b.Pad(input, AddParam(*Literal::CreateR0(3.14f), &b), padding_config); + Pad(input, AddParam(*Literal::CreateR0(3.14f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); @@ -376,7 +376,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); + Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -403,7 +403,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); + Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -430,7 +430,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding[dim]); } - b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); + Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -446,12 +446,12 @@ XLA_TEST_P(PadTestFloat, ReducePad) { XlaComputation add = CreateScalarAddComputation(FloatType(), &b); auto reduce = - b.Reduce(input, AddParam(*Literal::CreateR0(0.0), &b), add, {0}); + Reduce(input, AddParam(*Literal::CreateR0(0.0), &b), add, {0}); PaddingConfig padding_config = MakeNoPaddingConfig(3); padding_config.mutable_dimensions(0)->set_edge_padding_low(1); padding_config.mutable_dimensions(0)->set_edge_padding_high(1); - b.Pad(reduce, AddParam(*Literal::CreateR0(0.0f), &b), padding_config); + Pad(reduce, AddParam(*Literal::CreateR0(0.0f), &b), padding_config); Array3D expected({{{0.0, 0.0}, {0.0, 0.0}}, {{2.0, 2.0}, {2.0, 2.0}}, diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index 838f1b4e2f0f0e0871ec717bdeefcbbc653397e3..2620063aa492902a705690d28d8124d16184d635 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -46,7 +46,7 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param0"); ComputeAndCompareR0(&builder, 3.14159f, {param0_data.get()}, ErrorSpec(0.0001f)); @@ -58,7 +58,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0}), "param0"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {0}), "param0"); ComputeAndCompareR1(&builder, {}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -71,7 +71,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "param0"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0"); ComputeAndCompareR1(&builder, {3.14f, -100.25f}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -84,8 +84,9 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter( - 0, ShapeUtil::MakeShape(U8, {static_cast(str.size())}), "param0"); + Parameter(&builder, 0, + ShapeUtil::MakeShape(U8, {static_cast(str.size())}), + "param0"); ComputeAndCompareR1U8(&builder, str, {param0_data.get()}); } @@ -97,7 +98,7 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3, 0}), "param0"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 0}), "param0"); ComputeAndCompareR2(&builder, Array2D(3, 0), {param0_data.get()}, ErrorSpec(0.01f)); @@ -110,7 +111,7 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) { std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3, 2}), "param0"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 2}), "param0"); Array2D expected_array( {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}}); @@ -124,25 +125,25 @@ XLA_TEST_F(ParamsTest, TwoParameters) { std::unique_ptr literal0 = Literal::CreateR1({1, 2}); std::unique_ptr param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, literal0->shape(), "param0"); + auto param0 = Parameter(&builder, 0, literal0->shape(), "param0"); std::unique_ptr literal1 = Literal::CreateR1({10, 20}); std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param1 = builder.Parameter(1, literal1->shape(), "param1"); + auto param1 = Parameter(&builder, 1, literal1->shape(), "param1"); // Use both parameters // // {1, 2} + {10, 20} = {11, 22} - auto sum = builder.Add(param0, param1); - sum = builder.Add(param0, param1); + auto sum = Add(param0, param1); + sum = Add(param0, param1); // Use only the second parameter again, to show that it can be used // twice and to make the computation asymmetric in the two // parameters to test that the parameters are not swapped. // // {11, 22} * {10, 20} = {110, 440} - auto prod = builder.Mul(sum, param1); + Mul(sum, param1); ComputeAndCompareR1(&builder, {110, 440}, {param0_data.get(), param1_data.get()}, @@ -157,7 +158,7 @@ XLA_TEST_F(ParamsTest, MissingParameter) { client_->TransferToServer(*literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto p = builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "param2"); + Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {}), "param2"); auto computation_status = builder.Build(); ASSERT_NE(computation_status.status(), Status::OK()); @@ -169,12 +170,12 @@ XLA_TEST_F(ParamsTest, UnusedParameter) { std::unique_ptr literal0 = Literal::CreateR1({1, 2}); std::unique_ptr param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, literal0->shape(), "param0"); + Parameter(&builder, 0, literal0->shape(), "param0"); std::unique_ptr literal1 = Literal::CreateR1({10, 20}); std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param1 = builder.Parameter(1, literal1->shape(), "param1"); + Parameter(&builder, 1, literal1->shape(), "param1"); ComputeAndCompareR1(&builder, {10, 20}, {param0_data.get(), param1_data.get()}, @@ -194,14 +195,14 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) { std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, literal0->shape(), "param0"); - auto param1 = builder.Parameter(1, literal1->shape(), "param1"); - auto param2 = builder.Parameter(2, literal1->shape(), "param2"); + auto param0 = Parameter(&builder, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&builder, 1, literal1->shape(), "param1"); + auto param2 = Parameter(&builder, 2, literal1->shape(), "param2"); // This add is unused. - builder.Add(param1, param2); + Add(param1, param2); - builder.Neg(param0); + Neg(param0); ComputeAndCompareR1( &builder, {-1, -2}, @@ -215,7 +216,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { std::vector init_value = {{0, 1}}; init_value.resize(size); - XlaOp sum_handle = builder.ConstantR1(init_value); + XlaOp sum_handle = ConstantR1(&builder, init_value); std::vector sum = {{0, 1}}; sum.resize(size); @@ -233,8 +234,8 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { std::unique_ptr literal = Literal::CreateR1(sum_value); param_data_owner.push_back( client_->TransferToServer(*literal).ConsumeValueOrDie()); - XlaOp param = builder.Parameter(i, literal->shape(), "param"); - sum_handle = builder.Add(sum_handle, param); + XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + sum_handle = Add(sum_handle, param); } std::vector param_data; @@ -260,7 +261,7 @@ XLA_TEST_F(ParamsTest, XlaBuilder builder(TestName()); std::vector> param_data_owner; - XlaOp sum_handle = builder.ConstantR0(0.0f); + XlaOp sum_handle = ConstantR0(&builder, 0.0f); float target = 0.0; constexpr int kParamCount = 3000; for (int i = 0; i < kParamCount; ++i) { @@ -268,8 +269,8 @@ XLA_TEST_F(ParamsTest, std::unique_ptr literal = Literal::CreateR0(i); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = builder.Parameter(i, literal->shape(), "param"); - sum_handle = builder.Add(sum_handle, param); + XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + sum_handle = Add(sum_handle, param); } std::vector param_data; @@ -291,7 +292,7 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( XlaBuilder builder(TestName()); std::vector> param_data_owner; - XlaOp sum_handle = builder.ConstantR1({0, 0}); + XlaOp sum_handle = ConstantR1(&builder, {0, 0}); int32 target = 0; constexpr int kParamCount = 3000; std::vector params; @@ -300,17 +301,17 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( std::unique_ptr literal = Literal::CreateR1({i, i}); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = builder.Parameter(i, literal->shape(), "param"); + XlaOp param = Parameter(&builder, i, literal->shape(), "param"); params.push_back(param); - sum_handle = builder.Add(sum_handle, param); + sum_handle = Add(sum_handle, param); } std::vector outputs; for (int i = 0; i < kParamCount; ++i) { - outputs.push_back(builder.Add(params[i], sum_handle)); + outputs.push_back(Add(params[i], sum_handle)); } - builder.Tuple(outputs); + Tuple(&builder, outputs); std::vector param_data; param_data.reserve(param_data_owner.size()); @@ -356,7 +357,7 @@ XLA_TEST_F(ParamsTest, std::unique_ptr literal = Literal::CreateR1({i, i}); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = builder.Parameter(i, literal->shape(), "param"); + XlaOp param = Parameter(&builder, i, literal->shape(), "param"); params.push_back(param); parameter_shapes.push_back(literal->shape()); } @@ -367,11 +368,11 @@ XLA_TEST_F(ParamsTest, param_data_owner.push_back( std::move(client_->TransferToServer(*bool_literal)).ValueOrDie()); XlaOp bool_param = - builder.Parameter(kParamCount, bool_literal->shape(), "bool_param"); + Parameter(&builder, kParamCount, bool_literal->shape(), "bool_param"); params.push_back(bool_param); parameter_shapes.push_back(bool_literal->shape()); - auto init = builder.Tuple(params); + auto init = Tuple(&builder, params); // Create a computation for the condition: while(bool_param). Shape while_shape = ShapeUtil::MakeTupleShape(parameter_shapes); @@ -379,8 +380,8 @@ XLA_TEST_F(ParamsTest, { XlaBuilder builder("condition"); auto condition_parameter = - builder.Parameter(0, while_shape, "condition_parameter"); - builder.GetTupleElement(condition_parameter, kParamCount); + Parameter(&builder, 0, while_shape, "condition_parameter"); + GetTupleElement(condition_parameter, kParamCount); condition = builder.Build().ConsumeValueOrDie(); } @@ -389,27 +390,27 @@ XLA_TEST_F(ParamsTest, XlaComputation body; { XlaBuilder builder("body"); - auto body_parameter = builder.Parameter(0, while_shape, "body_parameter"); + auto body_parameter = Parameter(&builder, 0, while_shape, "body_parameter"); std::vector updates; for (int i = 0; i < kParamCount; ++i) { - auto add = builder.Add(builder.GetTupleElement(body_parameter, i), - builder.ConstantR1({1, 1})); + auto add = Add(GetTupleElement(body_parameter, i), + ConstantR1(&builder, {1, 1})); updates.push_back(add); } // Add bool parameter. - updates.push_back(builder.GetTupleElement(body_parameter, kParamCount)); + updates.push_back(GetTupleElement(body_parameter, kParamCount)); - builder.Tuple(updates); + Tuple(&builder, updates); body = builder.Build().ConsumeValueOrDie(); } - auto loop = builder.While(condition, body, init); + auto loop = While(condition, body, init); std::vector outputs; for (int i = 0; i < kParamCount; ++i) { - outputs.push_back(builder.GetTupleElement(loop, i)); + outputs.push_back(GetTupleElement(loop, i)); } - builder.Tuple(outputs); + Tuple(&builder, outputs); std::vector param_data; param_data.reserve(param_data_owner.size()); @@ -433,10 +434,10 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { Shape r1f32_3 = ShapeUtil::MakeShape(F32, {3}); Shape tuple_shape = ShapeUtil::MakeTupleShape({r1f32_3, r1f32_3}); - auto input = builder.Parameter(0, tuple_shape, "input"); - auto lhs = builder.GetTupleElement(input, 0); - auto rhs = builder.GetTupleElement(input, 1); - builder.Add(lhs, rhs); + auto input = Parameter(&builder, 0, tuple_shape, "input"); + auto lhs = GetTupleElement(input, 0); + auto rhs = GetTupleElement(input, 1); + Add(lhs, rhs); std::unique_ptr data = client_ @@ -457,7 +458,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { std::unique_ptr literal = Literal::CreateR2WithLayout( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1})); XlaBuilder builder(TestName()); - builder.Parameter(0, literal->shape(), "input"); + Parameter(&builder, 0, literal->shape(), "input"); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -469,7 +470,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { std::unique_ptr literal = Literal::CreateR2WithLayout( {{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0})); XlaBuilder builder(TestName()); - builder.Parameter(0, literal->shape(), "input"); + Parameter(&builder, 0, literal->shape(), "input"); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -478,7 +479,8 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { std::unique_ptr literal = Literal::CreateR2({ - {1, 3}, {2, 4}, + {1, 3}, + {2, 4}, }); const Shape original = literal->shape(); { @@ -494,9 +496,9 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { } // Use the original shape in building the computation. XlaBuilder builder(TestName()); - auto input = builder.Parameter(0, original, "input"); + auto input = Parameter(&builder, 0, original, "input"); // Use the slice operator to get an off-diagonal element. - builder.Slice(input, {0, 1}, {1, 2}, {1, 1}); + Slice(input, {0, 1}, {1, 2}, {1, 1}); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index 77159efb26f3b7dd4918f24305f7269a2d6ff647..5c351b2d113709105244de4aafa49d7cc535ced1 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -29,64 +29,63 @@ namespace { class PredTest : public ClientLibraryTestBase { protected: - void TestCompare( - bool lhs, bool rhs, bool expected, - XlaOp (XlaBuilder::*op)(const xla::XlaOp&, const xla::XlaOp&, - tensorflow::gtl::ArraySlice)) { + void TestCompare(bool lhs, bool rhs, bool expected, + std::function)> + op) { XlaBuilder builder(TestName()); - XlaOp lhs_op = builder.ConstantR0(lhs); - XlaOp rhs_op = builder.ConstantR0(rhs); - XlaOp result = (builder.*op)(lhs_op, rhs_op, {}); + XlaOp lhs_op = ConstantR0(&builder, lhs); + XlaOp rhs_op = ConstantR0(&builder, rhs); + op(lhs_op, rhs_op, {}); ComputeAndCompareR0(&builder, expected, {}); } }; TEST_F(PredTest, ConstantR0PredTrue) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR0(true); + ConstantR0(&builder, true); ComputeAndCompareR0(&builder, true, {}); } TEST_F(PredTest, ConstantR0PredFalse) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR0(false); + ConstantR0(&builder, false); ComputeAndCompareR0(&builder, false, {}); } TEST_F(PredTest, ConstantR0PredCompareEq) { - TestCompare(true, false, false, &XlaBuilder::Eq); + TestCompare(true, false, false, &Eq); } TEST_F(PredTest, ConstantR0PredCompareNe) { - TestCompare(true, false, true, &XlaBuilder::Ne); + TestCompare(true, false, true, &Ne); } TEST_F(PredTest, ConstantR0PredCompareLe) { - TestCompare(true, false, false, &XlaBuilder::Le); + TestCompare(true, false, false, &Le); } TEST_F(PredTest, ConstantR0PredCompareLt) { - TestCompare(true, false, false, &XlaBuilder::Lt); + TestCompare(true, false, false, &Lt); } TEST_F(PredTest, ConstantR0PredCompareGe) { - TestCompare(true, false, true, &XlaBuilder::Ge); + TestCompare(true, false, true, &Ge); } TEST_F(PredTest, ConstantR0PredCompareGt) { - TestCompare(true, false, true, &XlaBuilder::Gt); + TestCompare(true, false, true, &Gt); } TEST_F(PredTest, ConstantR1Pred) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({true, false, false, true}); + ConstantR1(&builder, {true, false, false, true}); ComputeAndCompareR1(&builder, {true, false, false, true}, {}); } TEST_F(PredTest, ConstantR2Pred) { XlaBuilder builder(TestName()); - auto a = - builder.ConstantR2({{false, true, true}, {true, false, false}}); + ConstantR2(&builder, {{false, true, true}, {true, false, false}}); const string expected = R"(pred[2,3] { { 011 }, { 100 } @@ -96,44 +95,44 @@ TEST_F(PredTest, ConstantR2Pred) { TEST_F(PredTest, AnyR1True) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({true, false}); - TF_ASSERT_OK(Any(a, &builder).status()); + auto a = ConstantR1(&builder, {true, false}); + Any(a); ComputeAndCompareR0(&builder, true, {}); } TEST_F(PredTest, AnyR1False) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({false, false}); - TF_ASSERT_OK(Any(a, &builder).status()); + auto a = ConstantR1(&builder, {false, false}); + Any(a); ComputeAndCompareR0(&builder, false, {}); } TEST_F(PredTest, AnyR1VacuouslyFalse) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - TF_ASSERT_OK(Any(a, &builder).status()); + auto a = ConstantR1(&builder, {}); + Any(a); ComputeAndCompareR0(&builder, false, {}); } TEST_F(PredTest, AnyR2True) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({ - {false, false, false}, - {false, false, false}, - {false, false, true}, - }); - TF_ASSERT_OK(Any(a, &builder).status()); + auto a = ConstantR2(&builder, { + {false, false, false}, + {false, false, false}, + {false, false, true}, + }); + Any(a); ComputeAndCompareR0(&builder, true, {}); } TEST_F(PredTest, AnyR2False) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({ - {false, false, false}, - {false, false, false}, - {false, false, false}, - }); - TF_ASSERT_OK(Any(a, &builder).status()); + auto a = ConstantR2(&builder, { + {false, false, false}, + {false, false, false}, + {false, false, false}, + }); + Any(a); ComputeAndCompareR0(&builder, false, {}); } diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 1a2de6937c3e134852a730f62f7b56417cf49b28..8e163e885d0d6315341c213577a3beb0180b679a 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -53,8 +53,8 @@ template std::unique_ptr PrngTest::UniformTest( T a, T b, tensorflow::gtl::ArraySlice dims, int64 seed) { XlaBuilder builder(TestName()); - builder.RngUniform( - builder.ConstantR0(a), builder.ConstantR0(b), + RngUniform( + ConstantR0(&builder, a), ConstantR0(&builder, b), ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), dims)); SetSeed(seed); @@ -141,9 +141,9 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count, int32 sample_size = range_size * expected_count; XlaBuilder builder(TestName()); - builder.RngUniform(builder.ConstantR0(0), - builder.ConstantR0(range_size), - ShapeUtil::MakeShape(S32, {sample_size})); + RngUniform(ConstantR0(&builder, 0), + ConstantR0(&builder, range_size), + ShapeUtil::MakeShape(S32, {sample_size})); SetSeed(seed); auto actual = @@ -184,9 +184,10 @@ XLA_TEST_F(PrngTest, MapUsingRng) { // Build a x -> (x + U[0,1)) computation. auto build_sum_rng = [this](XlaBuilder& builder) { auto b = builder.CreateSubBuilder("sum_with_rng"); - auto x = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "input"); - b->Add(x, b->RngUniform(b->ConstantR0(0), b->ConstantR0(1), - ShapeUtil::MakeShape(F32, {}))); + auto x = Parameter(b.get(), 0, ShapeUtil::MakeShape(F32, {}), "input"); + Add(x, + RngUniform(ConstantR0(b.get(), 0), ConstantR0(b.get(), 1), + ShapeUtil::MakeShape(F32, {}))); return b->BuildAndNoteError(); }; @@ -196,9 +197,9 @@ XLA_TEST_F(PrngTest, MapUsingRng) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr param0_data, client_->TransferToServer(*param0_literal)); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); auto fn = build_sum_rng(builder); - builder.Map({param0}, fn, {0}); + Map(&builder, {param0}, fn, {0}); TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); @@ -226,9 +227,8 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { // Build a U[0,1) computation. auto build_computation = [this]() { XlaBuilder builder(TestName()); - builder.RngUniform(builder.ConstantR0(0), - builder.ConstantR0(1), - ShapeUtil::MakeShape(F32, {10})); + RngUniform(ConstantR0(&builder, 0), ConstantR0(&builder, 1), + ShapeUtil::MakeShape(F32, {10})); return builder.Build(); }; @@ -282,8 +282,8 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { XLA_TEST_F(PrngTest, TenValuesN01) { XlaBuilder builder(TestName()); - builder.RngNormal(builder.ConstantR0(0), builder.ConstantR0(1), - ShapeUtil::MakeShape(F32, {10})); + RngNormal(ConstantR0(&builder, 0), ConstantR0(&builder, 1), + ShapeUtil::MakeShape(F32, {10})); SetSeed(42); ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); @@ -294,9 +294,9 @@ XLA_TEST_F(PrngTest, RngUniformCrash) { XlaBuilder builder(TestName()); // This used to crash XLA during LLVM IR generation for CPUs. - auto rng_uniform = builder.RngUniform(builder.ConstantR0(0), - builder.ConstantR0(1000 * 1000), - ShapeUtil::MakeShape(S32, {})); + RngUniform(ConstantR0(&builder, 0), + ConstantR0(&builder, 1000 * 1000), + ShapeUtil::MakeShape(S32, {})); SetSeed(0); ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); } diff --git a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc index f95e75648343aa88bd7c39de4ee9f387f2b60506..526a38e8d1dbed9cdd4a31bfbec49bc5c6bb174b 100644 --- a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc +++ b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc @@ -31,8 +31,8 @@ class QueryInferredShapeTest : public ClientLibraryTestBase {}; TEST_F(QueryInferredShapeTest, OnePlusOneShape) { XlaBuilder builder("one_plus_one"); - auto one = builder.ConstantR0(1.0); - auto result = builder.Add(one, one); + auto one = ConstantR0(&builder, 1.0); + auto result = Add(one, one); StatusOr shape_status = builder.GetShape(result); ASSERT_IS_OK(shape_status.status()); auto shape = shape_status.ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index b311785449f1774c3bc1e4d7ad35c2866e3b4061..4c1aa121067eed465c6128ea7a34e0284f7af43e 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -233,9 +233,9 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) { std::unique_ptr a_literal = Literal::CreateR1({input_values}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = builder.Parameter(0, a_literal->shape(), "a"); + auto a = Parameter(&builder, 0, a_literal->shape(), "a"); - builder.ReducePrecision(a, exponent_bits, mantissa_bits); + ReducePrecision(a, exponent_bits, mantissa_bits); ComputeAndCompareR1(&builder, expected_values, {a_data.get()}); } @@ -256,15 +256,15 @@ XLA_TEST_F(ReducePrecisionInsertionTest, std::unique_ptr a_literal = Literal::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = builder.Parameter(0, a_literal->shape(), "a"); + auto a = Parameter(&builder, 0, a_literal->shape(), "a"); // Abs doesn't affect resolution. - auto abs = builder.Abs(a); + auto abs = Abs(a); // Near 1.0, Log(x) approximates x - 1; this lets us confirm that the // reduce-precision operation showed up in the correct place in the // graph. - builder.Log(abs); + Log(abs); // Insert precision-reduction after the Abs(x) operation, rounding that // result to exactly 1.0f. @@ -285,11 +285,11 @@ XLA_TEST_F(ReducePrecisionInsertionTest, std::unique_ptr a_literal = Literal::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = builder.Parameter(0, a_literal->shape(), "a"); + auto a = Parameter(&builder, 0, a_literal->shape(), "a"); // These two operations should be fused by any reasonable backend. - auto abs = builder.Abs(a); - builder.Neg(abs); + auto abs = Abs(a); + Neg(abs); // Add a pass after operation fusion, suffixing kAbs operations. This // should not see into the fusion nodes and thus should not affect the @@ -311,11 +311,11 @@ XLA_TEST_F(ReducePrecisionInsertionTest, std::unique_ptr a_literal = Literal::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = builder.Parameter(0, a_literal->shape(), "a"); + auto a = Parameter(&builder, 0, a_literal->shape(), "a"); // These two operations should be fused by any reasonable backend. - auto abs = builder.Abs(a); - builder.Neg(abs); + auto abs = Abs(a); + Neg(abs); // Add a pass after operation fusion, suffixing kFusion operations. auto reduce_precision_pass = execution_options_.mutable_debug_options() @@ -335,11 +335,11 @@ XLA_TEST_F(ReducePrecisionInsertionTest, std::unique_ptr a_literal = Literal::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = builder.Parameter(0, a_literal->shape(), "a"); + auto a = Parameter(&builder, 0, a_literal->shape(), "a"); // These two operations should be fused by any reasonable backend. - auto abs = builder.Abs(a); - builder.Neg(abs); + auto abs = Abs(a); + Neg(abs); // Add a pass suffixing fusion nodes containing kCos operations. This // should have no effect. @@ -360,11 +360,11 @@ XLA_TEST_F(ReducePrecisionInsertionTest, std::unique_ptr a_literal = Literal::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = builder.Parameter(0, a_literal->shape(), "a"); + auto a = Parameter(&builder, 0, a_literal->shape(), "a"); // These two operations should be fused by any reasonable backend. - auto abs = builder.Abs(a); - builder.Neg(abs); + auto abs = Abs(a); + Neg(abs); // Add a pass suffixing fusion nodes containing kAbs operations. This // should see the kAbs operation within the above fusion node. diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index d671d40456a276a44b462f390c95aa4af301263a..c9f57cbb16729627a5e9ad3d49438295a286989e 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -89,9 +89,9 @@ class ReduceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {element_count}); - auto input = builder.Parameter(0, input_shape, "input"); - auto zero = builder.ConstantR0(0.0); - builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, 0.0); + Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0}); std::vector input_data(element_count); for (int64 i = 0; i < element_count; ++i) { @@ -118,20 +118,20 @@ class ReduceTest : public ClientLibraryTestBase { const int element_count = input_data.size(); XlaBuilder builder(TestName()); const Shape input_shape = ShapeUtil::MakeShape(S32, {element_count}); - auto input_par = builder.Parameter(0, input_shape, "input"); + auto input_par = Parameter(&builder, 0, input_shape, "input"); auto pred_values = - builder.Eq(input_par, builder.ConstantR1(element_count, 1)); + Eq(input_par, ConstantR1(&builder, element_count, 1)); XlaOp init_value; XlaComputation reduce; if (and_reduce) { - init_value = builder.ConstantR0(true); + init_value = ConstantR0(&builder, true); reduce = CreateScalarAndComputation(&builder); } else { - init_value = builder.ConstantR0(false); + init_value = ConstantR0(&builder, false); reduce = CreateScalarOrComputation(&builder); } - builder.Reduce(pred_values, init_value, reduce, - /*dimensions_to_reduce=*/{0}); + Reduce(pred_values, init_value, reduce, + /*dimensions_to_reduce=*/{0}); std::unique_ptr input_literal = Literal::CreateR1(input_data); std::unique_ptr input_global_data = @@ -156,21 +156,21 @@ class ReduceTest : public ClientLibraryTestBase { int64 major = 0) { XlaBuilder builder(TestName()); const Shape input_shape = ShapeUtil::MakeShape(U8, {rows, cols}); - auto input = builder.Parameter(0, input_shape, "input"); - auto input_pred = builder.Eq(input, builder.ConstantR0(1)); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto input_pred = Eq(input, ConstantR0(&builder, 1)); XlaOp init_value; XlaComputation reduce_op; if (and_reduce) { - init_value = builder.ConstantR0(true); + init_value = ConstantR0(&builder, true); reduce_op = CreateScalarAndComputation(&builder); } else { - init_value = builder.ConstantR0(false); + init_value = ConstantR0(&builder, false); reduce_op = CreateScalarOrComputation(&builder); } - builder.Reduce(input_pred, init_value, reduce_op, - /*dimensions_to_reduce=*/{0}); + Reduce(input_pred, init_value, reduce_op, + /*dimensions_to_reduce=*/{0}); Array2D input_data(rows, cols); input_data.FillRandom(0, 1); @@ -202,9 +202,9 @@ class ReduceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); - auto input = builder.Parameter(0, input_shape, "input"); - auto zero = builder.ConstantR0(0.0); - builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0, 1}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, 0.0); + Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0, 1}); Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); @@ -230,9 +230,9 @@ class ReduceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); - auto input = builder.Parameter(0, input_shape, "input"); - auto zero = builder.ConstantR0(0.0); - builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, 0.0); + Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0}); Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); @@ -287,10 +287,10 @@ class ReduceTest : public ClientLibraryTestBase { XlaComputation reduction_function = reduction_function_generator(&builder); const Shape input_shape = ShapeUtil::MakeShape( xla::primitive_util::NativeToPrimitiveType(), {rows, cols}); - auto input = builder.Parameter(0, input_shape, "input"); - auto zero = builder.ConstantR0(initial_value); - builder.Reduce(input, zero, reduction_function, - /*dimensions_to_reduce=*/{0}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, initial_value); + Reduce(input, zero, reduction_function, + /*dimensions_to_reduce=*/{0}); Array2D input_data(rows, cols); input_data.FillUnique(initial_value); @@ -442,10 +442,10 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { XlaBuilder builder(TestName()); XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); - auto input = builder.Parameter(0, input_shape, "input"); - auto zero = builder.ConstantR0(0.0); - auto log_ = builder.Log(input); - builder.Reduce(log_, zero, add_f32, /*dimensions_to_reduce=*/{0}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, 0.0); + auto log_ = Log(input); + Reduce(log_, zero, add_f32, /*dimensions_to_reduce=*/{0}); Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); @@ -473,11 +473,11 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { XlaBuilder builder(TestName()); XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); - auto input = builder.Parameter(0, input_shape, "input"); - auto zero = builder.ConstantR0(0.0); - auto log_ = builder.Log(input); - auto transpose = builder.Transpose(log_, {1, 0}); - builder.Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{1}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, 0.0); + auto log_ = Log(input); + auto transpose = Transpose(log_, {1, 0}); + Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{1}); Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); @@ -505,10 +505,10 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) { XlaBuilder builder(TestName()); XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {12, 111, 50}); - XlaOp input = builder.Parameter(0, input_shape, "input"); - XlaOp zero = builder.ConstantR0(0.0); - XlaOp transpose = builder.Transpose(input, /*permutation=*/{1, 0, 2}); - builder.Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0}); + XlaOp input = Parameter(&builder, 0, input_shape, "input"); + XlaOp zero = ConstantR0(&builder, 0.0); + XlaOp transpose = Transpose(input, /*permutation=*/{1, 0, 2}); + Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, MakeFakeLiteral(input_shape)); @@ -522,11 +522,11 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { XlaBuilder builder(TestName()); XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, 2, cols / 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto zero = builder.ConstantR0(0.0); - auto log_ = builder.Tanh(input); - auto reshape = builder.Reshape(log_, {rows, cols}); - builder.Reduce(reshape, zero, add_f32, /*dimensions_to_reduce=*/{0}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, 0.0); + auto log_ = Tanh(input); + auto reshape = Reshape(log_, {rows, cols}); + Reduce(reshape, zero, add_f32, /*dimensions_to_reduce=*/{0}); Array3D input_data(rows, 2, cols / 2); input_data.FillRandom(3.14f, 0.04); @@ -568,9 +568,9 @@ void PrintTo(const BoundsLayout& spec, std::ostream* os) { XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) { XlaBuilder builder(TestName()); auto add = CreateScalarAddComputation(F32, &builder); - auto scalar = builder.ConstantR0(42.0); - auto broadcasted = builder.Broadcast(scalar, {500, 500}); - builder.Reduce(broadcasted, builder.ConstantR0(0.0f), add, {0, 1}); + auto scalar = ConstantR0(&builder, 42.0); + auto broadcasted = Broadcast(scalar, {500, 500}); + Reduce(broadcasted, ConstantR0(&builder, 0.0f), add, {0, 1}); float expected = 42.0f * static_cast(500 * 500); ComputeAndCompareR0(&builder, expected, {}, ErrorSpec(0.0001)); @@ -580,9 +580,9 @@ XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) { XLA_TEST_F(ReduceTest, MaxReduce2DScalarToR0) { XlaBuilder builder(TestName()); auto max = CreateScalarMaxComputation(F32, &builder); - auto scalar = builder.ConstantR0(42.0); - auto broadcasted = builder.Broadcast(scalar, {500, 500}); - builder.Reduce(broadcasted, builder.ConstantR0(0.0f), max, {0, 1}); + auto scalar = ConstantR0(&builder, 42.0); + auto broadcasted = Broadcast(scalar, {500, 500}); + Reduce(broadcasted, ConstantR0(&builder, 0.0f), max, {0, 1}); float expected = 42.0f; ComputeAndCompareR0(&builder, expected, {}, ErrorSpec(0.0001)); @@ -595,8 +595,8 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) { Array2D input(300, 250); input.FillRandom(214.0f); auto input_literal = Literal::CreateR2FromArray2D(input); - builder.Reduce(builder.ConstantLiteral(*input_literal), - builder.ConstantR0(FLT_MIN), max, {0, 1}); + Reduce(ConstantLiteral(&builder, *input_literal), + ConstantR0(&builder, FLT_MIN), max, {0, 1}); auto input_max = FLT_MIN; input.Each( [&](int64, int64, float* v) { input_max = std::max(input_max, *v); }); @@ -610,8 +610,8 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) { Array2D input(150, 130); input.FillRandom(214.0f); auto input_literal = Literal::CreateR2FromArray2D(input); - builder.Reduce(builder.ConstantLiteral(*input_literal), - builder.ConstantR0(FLT_MAX), min, {0, 1}); + Reduce(ConstantLiteral(&builder, *input_literal), + ConstantR0(&builder, FLT_MAX), min, {0, 1}); auto input_min = FLT_MAX; input.Each( @@ -625,10 +625,9 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) { auto min = CreateScalarMinComputation(U32, &builder); auto input_literal = Literal::CreateR2FromArray2D(input); auto initial_value = - builder.ConstantR0(std::numeric_limits::max()); + ConstantR0(&builder, std::numeric_limits::max()); - builder.Reduce(builder.ConstantLiteral(*input_literal), initial_value, min, - {0, 1}); + Reduce(ConstantLiteral(&builder, *input_literal), initial_value, min, {0, 1}); ComputeAndCompareR0(&builder, 1, {}); } @@ -638,19 +637,18 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) { auto max = CreateScalarMaxComputation(U32, &builder); auto input_literal = Literal::CreateR2FromArray2D(input); auto initial_value = - builder.ConstantR0(std::numeric_limits::min()); + ConstantR0(&builder, std::numeric_limits::min()); - builder.Reduce(builder.ConstantLiteral(*input_literal), initial_value, max, - {0, 1}); + Reduce(ConstantLiteral(&builder, *input_literal), initial_value, max, {0, 1}); ComputeAndCompareR0(&builder, 2, {}); } // Reduces a matrix among dimension 1. XLA_TEST_F(ReduceTest, Reduce2DAmong1) { XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_2d_); + auto m = ConstantLiteral(&builder, *literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {1}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {1}); std::vector expected = {6.f, 15.f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -659,9 +657,9 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong1) { XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) { // Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar). XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_2d_); + auto m = ConstantLiteral(&builder, *literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {0, 1}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1}); ComputeAndCompareR0(&builder, 21.0f, {}, ErrorSpec(0.0001, 1e-4)); } @@ -669,9 +667,9 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) { // Tests 2D matrix ReduceToRow operation. XLA_TEST_F(ReduceTest, Reduce2DAmongY) { XlaBuilder builder("reduce_among_y"); - auto m = builder.ConstantLiteral(*literal_2d_); + auto m = ConstantLiteral(&builder, *literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {0}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {0}); std::vector expected = {5.f, 7.f, 9.f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -679,9 +677,9 @@ XLA_TEST_F(ReduceTest, Reduce2DAmongY) { XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_3d_); + auto m = ConstantLiteral(&builder, *literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {1, 2}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {1, 2}); std::vector expected = {21.f, 21.f, 21.f, 21.f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -689,9 +687,9 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_3d_); + auto m = ConstantLiteral(&builder, *literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {0, 1}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1}); std::vector expected = {20.f, 28.f, 36.f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -699,9 +697,9 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { XLA_TEST_F(ReduceTest, ReduceR3ToR0) { XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_3d_); + auto m = ConstantLiteral(&builder, *literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {0, 1, 2}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1, 2}); float expected = 21.0f * 4.0; ComputeAndCompareR0(&builder, expected, {}, ErrorSpec(0.0001)); @@ -709,9 +707,9 @@ XLA_TEST_F(ReduceTest, ReduceR3ToR0) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_3d_); + auto m = ConstantLiteral(&builder, *literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {0}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {0}); // clang-format off Array2D expected({ @@ -724,9 +722,9 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_3d_); + auto m = ConstantLiteral(&builder, *literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {1}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {1}); // clang-format off Array2D expected({ @@ -741,9 +739,9 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) { XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_3d_); + auto m = ConstantLiteral(&builder, *literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {2}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {2}); // clang-format off Array2D expected({ @@ -827,10 +825,10 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) { client_->TransferToServer(*input_literal).ConsumeValueOrDie(); auto input_activations = - builder.Parameter(0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal->shape(), "input"); XlaComputation add = CreateScalarAddComputation(F32, &builder); - auto sum = builder.Reduce(input_activations, builder.ConstantR0(0.0f), - add, GetParam().reduce_dims); + Reduce(input_activations, ConstantR0(&builder, 0.0f), add, + GetParam().reduce_dims); auto expected = ReferenceUtil::Reduce3DTo2D(input_array, 0.0f, GetParam().reduce_dims, @@ -871,14 +869,14 @@ XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) { XlaBuilder builder(TestName()); XlaComputation max_f32 = CreateScalarMaxComputation(F32, &builder); - auto a = builder.ConstantR0(2.0f); - auto a2 = builder.Abs(a); + auto a = ConstantR0(&builder, 2.0f); + auto a2 = Abs(a); std::unique_ptr b_literal = Literal::CreateR1({1.0f, 4.0f}); std::unique_ptr b_data = client_->TransferToServer(*b_literal).ConsumeValueOrDie(); - auto b = builder.Parameter(0, b_literal->shape(), "b"); - auto max = builder.Reduce(b, a2, max_f32, {0}); + auto b = Parameter(&builder, 0, b_literal->shape(), "b"); + Reduce(b, a2, max_f32, {0}); ComputeAndCompareR0(&builder, 4.0f, {b_data.get()}); } @@ -900,13 +898,13 @@ class ReduceInitializerTest : public ReduceTest { XlaComputation max_fn = CreateScalarMaxComputation( primitive_util::NativeToPrimitiveType(), &builder); - auto init = builder.ConstantR0(initializer); + auto init = ConstantR0(&builder, initializer); std::vector input_arr(num_elems, std::numeric_limits::lowest()); auto input_literal = Literal::CreateR1(input_arr); auto input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - builder.Reduce(builder.Parameter(0, input_literal->shape(), "input"), init, - max_fn, {0}); + Reduce(Parameter(&builder, 0, input_literal->shape(), "input"), init, + max_fn, {0}); ComputeAndCompareR0(&builder, initializer, {input_data.get()}); } @@ -939,15 +937,15 @@ XLA_TEST_F(ReduceInitializerTest, U64InitializerBigValue) { XLA_TEST_F(ReduceTest, ReduceIdentity) { XlaBuilder builder(TestName()); Shape single_float = ShapeUtil::MakeShape(F32, {}); - builder.Parameter(0, single_float, "lhs-unused"); - builder.Parameter(1, single_float, "rhs-used"); + Parameter(&builder, 0, single_float, "lhs-unused"); + Parameter(&builder, 1, single_float, "rhs-used"); auto computation_status = builder.Build(); TF_ASSERT_OK(computation_status.status()); Shape operand_shape = ShapeUtil::MakeShape(F32, {1}); - builder.Reduce(builder.Parameter(0, operand_shape, "operand"), - builder.Parameter(1, single_float, "init"), - computation_status.ValueOrDie(), {0}); + Reduce(Parameter(&builder, 0, operand_shape, "operand"), + Parameter(&builder, 1, single_float, "init"), + computation_status.ValueOrDie(), {0}); float operand[] = {42.0f}; float init = 58.5f; diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 266760e8202fddc48792ac66dda334255e428808..741974480c6a862a7794aa6257f131a5893e963d 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -72,9 +72,9 @@ class ReduceWindowTest : public ::testing::WithParamInterface, Padding padding) { auto init = CreateConstantFromLiteral(*Literal::CreateR0(0.0f), &builder_); - builder_.ReduceWindow(input, init, - CreateScalarAddComputation(FloatType(), &builder_), - window_dimensions, window_strides, padding); + ReduceWindow(input, init, + CreateScalarAddComputation(FloatType(), &builder_), + window_dimensions, window_strides, padding); } void ReduceWindowMax(const XlaOp& input, @@ -82,9 +82,9 @@ class ReduceWindowTest : public ::testing::WithParamInterface, tensorflow::gtl::ArraySlice window_strides, Padding padding) { auto init = CreateConstantFromLiteral(Literal::MinValue(F32), &builder_); - builder_.ReduceWindow(input, init, - CreateScalarMaxComputation(FloatType(), &builder_), - window_dimensions, window_strides, padding); + ReduceWindow(input, init, + CreateScalarMaxComputation(FloatType(), &builder_), + window_dimensions, window_strides, padding); } void ReduceWindowMin(const XlaOp& input, @@ -92,9 +92,9 @@ class ReduceWindowTest : public ::testing::WithParamInterface, tensorflow::gtl::ArraySlice window_strides, Padding padding) { auto init = CreateConstantFromLiteral(Literal::MaxValue(F32), &builder_); - builder_.ReduceWindow(input, init, - CreateScalarMinComputation(FloatType(), &builder_), - window_dimensions, window_strides, padding); + ReduceWindow(input, init, + CreateScalarMinComputation(FloatType(), &builder_), + window_dimensions, window_strides, padding); } XlaBuilder builder_; @@ -106,10 +106,10 @@ TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { const auto init_value = CreateConstantFromLiteral(*Literal::CreateR0(0), &builder_); TF_ASSERT_OK(builder_.first_error()); - builder_.ReduceWindow(input, init_value, - CreateScalarAddComputation(FloatType(), &builder_), - /*window_dimensions=*/{1, 2}, - /*window_strides=*/{1}, Padding::kValid); + ReduceWindow(input, init_value, + CreateScalarAddComputation(FloatType(), &builder_), + /*window_dimensions=*/{1, 2}, + /*window_strides=*/{1}, Padding::kValid); ASSERT_EQ(builder_.first_error().code(), tensorflow::error::INVALID_ARGUMENT) << builder_.first_error(); ASSERT_THAT(builder_.first_error().error_message(), @@ -122,10 +122,9 @@ TEST_P(ReduceWindowTest, R0ReduceWindow) { CreateConstantFromLiteral(*Literal::CreateR0(42.0), &builder_); const auto init = CreateConstantFromLiteral(*Literal::CreateR0(1.0), &builder_); - builder_.ReduceWindow(input, init, - CreateScalarAddComputation(FloatType(), &builder_), - /*window_dimensions=*/{}, - /*window_strides=*/{}, Padding::kSame); + ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_), + /*window_dimensions=*/{}, + /*window_strides=*/{}, Padding::kSame); ComputeAndCompareLiteral(&builder_, *Literal::CreateR0(43.0), {}, ErrorSpec(0.00001)); } @@ -306,13 +305,13 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { Padding padding = Padding::kValid; const Shape scalar = ShapeUtil::MakeShape(FloatType(), {}); auto b = builder_.CreateSubBuilder("unusual"); - auto lhs = b->Parameter(0, scalar, "lhs"); - auto rhs = b->Parameter(1, scalar, "rhs"); - b->Min(b->Add(lhs, rhs), - CreateConstantFromLiteral(*Literal::CreateR0(8.0f), b.get())); + auto lhs = Parameter(b.get(), 0, scalar, "lhs"); + auto rhs = Parameter(b.get(), 1, scalar, "rhs"); + Min(Add(lhs, rhs), + CreateConstantFromLiteral(*Literal::CreateR0(8.0f), b.get())); XlaComputation reduce_fn = b->BuildAndNoteError(); - builder_.ReduceWindow( + ReduceWindow( input, CreateConstantFromLiteral(*Literal::CreateR0(0.0f), &builder_), reduce_fn, @@ -542,7 +541,7 @@ TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { Array2D input_array(6, 4, 1.0f); - XlaOp input = builder_.Broadcast( + XlaOp input = Broadcast( CreateConstantFromLiteral(Literal::One(F32), &builder_), {6, 4}); Padding padding = Padding::kSame; @@ -627,7 +626,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, auto computation = param.reducer == kAdd ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); - b.ReduceWindowWithGeneralPadding( + ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, /*computation=*/computation, @@ -968,11 +967,11 @@ TEST_P(R3ReduceWindowTest, Add) { &b, ¶meter); auto init_value = CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); - b.ReduceWindow(/*operand=*/parameter, - /*init_value=*/init_value, - /*computation=*/CreateScalarAddComputation(FloatType(), &b), - /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/param.padding); + ReduceWindow(/*operand=*/parameter, + /*init_value=*/init_value, + /*computation=*/CreateScalarAddComputation(FloatType(), &b), + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, /*padding=*/param.padding); auto expected = ReferenceUtil::ReduceWindow3DAdd( /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, @@ -1109,7 +1108,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, : CreateScalarMaxComputation(FloatType(), &b); auto init_value = CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); - b.ReduceWindowWithGeneralPadding( + ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, /*computation=*/computation, @@ -1306,7 +1305,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { : CreateScalarMaxComputation(FloatType(), &b); auto init_value = CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); - b.ReduceWindowWithGeneralPadding( + ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, /*computation=*/computation, diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index 36d763b0f7f4267ede076c0b25cfaf9654e96e0d..bebd814fa8b863428750dc12a93d1ef5ad7e6685 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -39,8 +39,8 @@ class ReplayTest : public ClientLibraryTestBase {}; TEST_F(ReplayTest, TwoPlusTwoReplay) { // Make 2+2 computation. XlaBuilder builder(TestName()); - auto two = builder.ConstantR0(2); - builder.Add(two, two); + auto two = ConstantR0(&builder, 2); + Add(two, two); XlaComputation computation = builder.Build().ConsumeValueOrDie(); // Serialize it out. @@ -70,9 +70,9 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) { XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { // Make computation. XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(S32, {}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(S32, {}), "y"); + Add(x, y); XlaComputation computation = builder.Build().ConsumeValueOrDie(); // Serialize it out. @@ -111,13 +111,13 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { // As above, but with map(+2) over some constant array. XlaBuilder plus_two_builder("plus two"); auto input = - plus_two_builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "input"); - plus_two_builder.Add(input, plus_two_builder.ConstantR0(2)); + Parameter(&plus_two_builder, 0, ShapeUtil::MakeShape(S32, {}), "input"); + Add(input, ConstantR0(&plus_two_builder, 2)); XlaComputation plus_two = plus_two_builder.Build().ConsumeValueOrDie(); XlaBuilder mapper_builder(TestName()); - auto original = mapper_builder.ConstantR1({1, 2, 3}); - mapper_builder.Map({original}, plus_two, {0}); + auto original = ConstantR1(&mapper_builder, {1, 2, 3}); + Map(&mapper_builder, {original}, plus_two, {0}); XlaComputation computation = mapper_builder.Build().ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc index da1b588ec41cef711412367e89b2a9b1029bca71..5812fe442b25da1b7e34494d00fe8025d29b2802 100644 --- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc @@ -44,11 +44,11 @@ using ReshapeMotionTest = ClientLibraryTestBase; TEST_F(ReshapeMotionTest, ElementwiseOfReshapesWithNonSameInputShapes) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{2, 3, 5}, {7, 11, 13}}); - auto b = builder.ConstantR2({{17, 19}, {23, 29}, {31, 37}}); - auto c = builder.Reshape(a, {6}); - auto d = builder.Reshape(b, {6}); - auto e = builder.Mul(c, d); + auto a = ConstantR2(&builder, {{2, 3, 5}, {7, 11, 13}}); + auto b = ConstantR2(&builder, {{17, 19}, {23, 29}, {31, 37}}); + auto c = Reshape(a, {6}); + auto d = Reshape(b, {6}); + Mul(c, d); ComputeAndCompareR1(&builder, {34, 57, 115, 203, 341, 481}, {}); } diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index a4580cd71d46ad0a0186eddd51291f9c322b6f49..d3d6c3c7d703161e433740acbbd58d51ba1434af 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -59,7 +59,7 @@ XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = Literal::CreateR1({1.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, @@ -72,7 +72,7 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{}); + Collapse(/*operand=*/parameter, /*dimensions=*/{}); auto expected_literal = Literal::CreateR1({1.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, @@ -85,7 +85,7 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0}); + Collapse(/*operand=*/parameter, /*dimensions=*/{0}); auto expected_literal = Literal::CreateR1({1.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, @@ -101,8 +101,8 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); - auto reshape = builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, - /*new_sizes=*/{}); + auto reshape = Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{}); auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie(); auto expected_literal = Literal::CreateR0(1.0f); @@ -117,34 +117,28 @@ XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", &builder, ¶meter); - auto a = builder.Neg(parameter); - builder.Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); + auto a = Neg(parameter); + Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); auto expected_literal = Literal::CreateR1({-1.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3)) { +XLA_TEST_P(ReshapeTest, Trivial0x3) { XlaBuilder builder(TestName()); Array2D input_array(0, 3); auto input_literal = Literal::CreateR2FromArray2D(input_array); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = Literal::CreateR1({}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-05-15 -// with an incorrect result rank. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { +XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = @@ -152,23 +146,20 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = Literal::CreateR1({}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial3x0)) { +XLA_TEST_P(ReshapeTest, Trivial3x0) { XlaBuilder builder(TestName()); Array2D input_array(3, 0); auto input_literal = Literal::CreateR2FromArray2D(input_array); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = Literal::CreateR1({}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); @@ -181,7 +172,7 @@ XLA_TEST_P(ReshapeTest, Trivial1x3) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); @@ -194,25 +185,21 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -// // Splits an empty vector into an empty matrix. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) { +XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) { XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR1({}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, - /*new_sizes=*/{2, 0}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0}, + /*new_sizes=*/{2, 0}); auto expected_literal = Literal::CreateR2({{}, {}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); @@ -226,27 +213,23 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, - /*new_sizes=*/{2, 3}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0}, + /*new_sizes=*/{2, 3}); auto expected_literal = Literal::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -// // Transposes a 2x0 array to a 0x2 array. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) { +XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) { XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D(0, 2)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 0}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 0}); auto expected_literal = Literal::CreateR2({{}, {}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); @@ -260,8 +243,8 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, - /*new_sizes=*/{3, 1}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{3, 1}); auto expected = ReferenceUtil::TransposeArray2D(*simple); auto expected_literal = Literal::CreateFromArray(*expected); @@ -277,8 +260,8 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, - /*new_sizes=*/{3, 4}); + Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{3, 4}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); auto expected_literal = Literal::CreateFromArray(*expected); @@ -286,18 +269,14 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) { zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -// // Transposes a 0x4 array with XlaBuilder::Transpose. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) { +XLA_TEST_P(ReshapeTest, Transpose0x4) { XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D(0, 4)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Transpose(parameter, {1, 0}); + Transpose(parameter, {1, 0}); auto expected_literal = Literal::CreateR2({{}, {}, {}, {}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); @@ -311,7 +290,7 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Transpose(parameter, {1, 0}); + Transpose(parameter, {1, 0}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); auto expected_literal = Literal::CreateFromArray(*expected); @@ -319,36 +298,29 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) { zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -// // Reshapes an empty 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) { +XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffleZeroElements) { XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D(6, 0)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 3, 0, 0}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 3, 0, 0}); auto expected_literal = Literal::CreateFromArray(Array4D(2, 3, 0, 0)); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeR4ToR2ZeroElements)) { +XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) { XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array4D(2, 3, 4, 0)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, - /*new_sizes=*/{24, 0}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, + /*new_sizes=*/{24, 0}); auto expected_literal = Literal::CreateFromArray(Array2D(24, 0)); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); @@ -363,8 +335,8 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 6}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 6}); auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6); auto expected_literal = Literal::CreateFromArray(*expected); @@ -372,18 +344,14 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -// -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitAndShuffleZeroElements)) { +XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D(0, 6)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, - /*new_sizes=*/{3, 0}); + Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{3, 0}); auto expected_literal = Literal::CreateFromArray(Array2D(3, 0)); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); @@ -398,8 +366,8 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, - /*new_sizes=*/{2, 6}); + Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{2, 6}); Array2D expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}}); auto expected_literal = Literal::CreateFromArray(expected); @@ -424,8 +392,8 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, - /*new_sizes=*/{24}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, + /*new_sizes=*/{24}); auto expected_literal = Literal::CreateR1( {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}); @@ -439,8 +407,8 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, - /*new_sizes=*/{8, 3}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, + /*new_sizes=*/{8, 3}); auto expected_literal = Literal::CreateR2({{10, 11, 12}, {15, 16, 17}, {20, 21, 22}, @@ -459,8 +427,8 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{24}); + Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{24}); auto expected_literal = Literal::CreateR1( {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}); @@ -474,8 +442,8 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{8, 3}); + Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{8, 3}); auto expected_literal = Literal::CreateR2({{10, 20, 30}, {40, 11, 21}, {31, 41, 12}, @@ -494,8 +462,8 @@ XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{2, 6, 2}); + Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{2, 6, 2}); auto expected_literal = Literal::CreateR3( {{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}}, {{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}}); @@ -527,7 +495,7 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); + Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); auto expected_literal = Literal::CreateR2( {{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, @@ -552,8 +520,8 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, - /*new_sizes=*/{2, 4}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, + /*new_sizes=*/{2, 4}); auto expected_literal = Literal::CreateR2({{0, 1, 2, 3}, {4, 5, 6, 7}}); @@ -575,7 +543,7 @@ XLA_TEST_P(ReshapeTest, ToScalar) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b, ¶meter); - b.Reshape(parameter, dimensions, {}); + Reshape(parameter, dimensions, {}); auto expected_literal = Literal::CreateR0(83.0f); ComputeAndCompareLiteral(&b, *expected_literal, {input.get()}, @@ -589,7 +557,7 @@ XLA_TEST_P(ReshapeTest, BadDimensions) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, ¶meter); - b.Reshape(parameter, {}, {}); + Reshape(parameter, {}, {}); EXPECT_THAT( ExecuteToString(&b, {}), ::testing::HasSubstr("not a permutation of the operand dimensions")); @@ -601,7 +569,7 @@ XLA_TEST_P(ReshapeTest, BadNewSizes) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, ¶meter); - b.Reshape(parameter, {1}, {}); + Reshape(parameter, {1}, {}); EXPECT_THAT(ExecuteToString(&b, {}), ::testing::HasSubstr("mismatched element counts")); } @@ -637,7 +605,7 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); + Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); Array2D expected_array({ {0, 1, 2, 3, 100, 101, 102, 103}, @@ -671,7 +639,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); + Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off auto expected_literal = Literal::CreateR4({ @@ -698,7 +666,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); + Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off auto expected_literal = Literal::CreateR4({ @@ -728,7 +696,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); + Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); std::unique_ptr expected = Literal::ReshapeSlice({2, 1}, {1, 0}, *input_literal); @@ -750,7 +718,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); + Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); std::unique_ptr expected = Literal::ReshapeSlice({4, 2}, {1, 0}, *input_literal); @@ -773,8 +741,8 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 2, 1, 3}, - /*new_sizes=*/{5, 60}); + Reshape(parameter, /*dimensions=*/{0, 2, 1, 3}, + /*new_sizes=*/{5, 60}); Array2D expected_array(5, 60); input.Each([&](tensorflow::gtl::ArraySlice indices, float* cell) { @@ -800,8 +768,8 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{3, 0, 1, 2}, - /*new_sizes=*/{7, 2, 3, 5}); + Reshape(parameter, /*dimensions=*/{3, 0, 1, 2}, + /*new_sizes=*/{7, 2, 3, 5}); XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecutionOptions execution_options = execution_options_; @@ -833,8 +801,8 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, - /*new_sizes=*/{1, 2, 3, 4}); + Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, + /*new_sizes=*/{1, 2, 3, 4}); ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {input.get()}); } @@ -848,8 +816,8 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{1, 3, 2, 0}, - /*new_sizes=*/{2, 4, 3, 1}); + Reshape(parameter, /*dimensions=*/{1, 3, 2, 0}, + /*new_sizes=*/{2, 4, 3, 1}); // clang-format off auto expected_2x4x3x1 = Literal::CreateR4( @@ -882,8 +850,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, - /*new_sizes=*/new_bounds); + Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -911,8 +879,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, - /*new_sizes=*/new_bounds); + Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -940,8 +908,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, - /*new_sizes=*/new_bounds); + Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -970,8 +938,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, - /*new_sizes=*/new_bounds); + Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -999,8 +967,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{1, 0, 2, 3}, - /*new_sizes=*/new_bounds); + Reshape(parameter, /*dimensions=*/{1, 0, 2, 3}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = Literal::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal) diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index e7bd142dc9ddefbd8bebfb77d72218d662645c31..662bc42224851ac19c690129f525953e6d410a55 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -87,7 +87,7 @@ TEST_P(FloatReverseTest, Reverses) { XlaBuilder builder(TestName()); auto a = AddParam(*input_literal, &builder); - builder.Rev(a, spec.reversal); + Rev(a, spec.reversal); std::unique_ptr expected = input_literal->CloneToUnique(); std::vector output_indices(spec.input_dims.size()); @@ -127,7 +127,7 @@ XLA_TEST_F(ReverseTest, Reverse4DU8ArrayOnDim23) { }}); // clang-format on - b.Rev(b.ConstantR4FromArray4D(input), {0, 3}); + Rev(ConstantR4FromArray4D(&b, input), {0, 3}); // clang-format off Array4D expected({{ @@ -163,7 +163,7 @@ TEST_F(ReverseTest, Reverse4DFloatArrayOnDim01) { }); // clang-format on - b.Rev(b.ConstantR4FromArray4D(input), {0, 1}); + Rev(ConstantR4FromArray4D(&b, input), {0, 1}); // clang-format off Array4D expected({ diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 308d3fc78a51e63c0e3db8c0cda18caf11f665bd..bc994315c3c725e3c0a860b8016126a03ae73f58 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -44,74 +44,75 @@ class ScalarComputationsTest : public ClientLibraryTestBase { protected: // A template for building and running a binary comparison test. template - void TestCompare( - NativeT lhs, NativeT rhs, bool expected, - XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&, - tensorflow::gtl::ArraySlice)) { + void TestCompare(NativeT lhs, NativeT rhs, bool expected, + std::function)> + op) { XlaBuilder builder(TestName()); - XlaOp lhs_op = builder.ConstantR0(lhs); - XlaOp rhs_op = builder.ConstantR0(rhs); - XlaOp result = (builder.*op)(lhs_op, rhs_op, {}); + XlaOp lhs_op = ConstantR0(&builder, lhs); + XlaOp rhs_op = ConstantR0(&builder, rhs); + op(lhs_op, rhs_op, {}); ComputeAndCompareR0(&builder, expected, {}); } template void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected, - XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&, - tensorflow::gtl::ArraySlice)) { + std::function)> + op) { XlaBuilder builder(TestName()); - XlaOp lhs_op = builder.ConstantR0(lhs); - XlaOp rhs_op = builder.ConstantR0(rhs); - XlaOp result = (builder.*op)(lhs_op, rhs_op, {}); + XlaOp lhs_op = ConstantR0(&builder, lhs); + XlaOp rhs_op = ConstantR0(&builder, rhs); + op(lhs_op, rhs_op, {}); ComputeAndCompareR0(&builder, expected, {}); } }; XLA_TEST_F(ScalarComputationsTest, ReturnScalarF32) { XlaBuilder builder(TestName()); - builder.ConstantR0(2.1f); + ConstantR0(&builder, 2.1f); ComputeAndCompareR0(&builder, 2.1f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, NegateScalarF32) { XlaBuilder builder(TestName()); - builder.Neg(builder.ConstantR0(2.1f)); + Neg(ConstantR0(&builder, 2.1f)); ComputeAndCompareR0(&builder, -2.1f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, NegateScalarS32) { XlaBuilder builder(TestName()); - builder.Neg(builder.ConstantR0(2)); + Neg(ConstantR0(&builder, 2)); ComputeAndCompareR0(&builder, -2, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF32) { XlaBuilder builder(TestName()); - builder.Add(builder.ConstantR0(2.1f), builder.ConstantR0(5.5f)); + Add(ConstantR0(&builder, 2.1f), ConstantR0(&builder, 5.5f)); ComputeAndCompareR0(&builder, 7.6f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS32) { XlaBuilder builder(TestName()); - builder.Add(builder.ConstantR0(2), builder.ConstantR0(5)); + Add(ConstantR0(&builder, 2), ConstantR0(&builder, 5)); ComputeAndCompareR0(&builder, 7, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU32) { XlaBuilder builder(TestName()); - builder.Add(builder.ConstantR0(35), builder.ConstantR0(57)); + Add(ConstantR0(&builder, 35), ConstantR0(&builder, 57)); ComputeAndCompareR0(&builder, 92, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU8) { XlaBuilder builder(TestName()); - builder.Add(builder.ConstantR0(35), builder.ConstantR0(57)); + Add(ConstantR0(&builder, 35), ConstantR0(&builder, 57)); ComputeAndCompareR0(&builder, 92, {}); } @@ -120,7 +121,7 @@ XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU64) { XlaBuilder builder(TestName()); const uint64 a = static_cast(1) << 63; const uint64 b = a + 1; - builder.Add(builder.ConstantR0(a), builder.ConstantR0(b)); + Add(ConstantR0(&builder, a), ConstantR0(&builder, b)); ComputeAndCompareR0(&builder, a + b, {}); } @@ -129,37 +130,36 @@ XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS64) { XlaBuilder builder(TestName()); const int64 a = static_cast(1) << 62; const int64 b = a - 1; - builder.Add(builder.ConstantR0(a), builder.ConstantR0(b)); + Add(ConstantR0(&builder, a), ConstantR0(&builder, b)); ComputeAndCompareR0(&builder, a + b, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF64) { XlaBuilder builder(TestName()); - builder.Add(builder.ConstantR0(0.25), - builder.ConstantR0(3.5)); + Add(ConstantR0(&builder, 0.25), ConstantR0(&builder, 3.5)); ComputeAndCompareR0(&builder, 3.75, {}); } XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsF32) { XlaBuilder builder(TestName()); - builder.Sub(builder.ConstantR0(2.1f), builder.ConstantR0(5.5f)); + Sub(ConstantR0(&builder, 2.1f), ConstantR0(&builder, 5.5f)); ComputeAndCompareR0(&builder, -3.4f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) { XlaBuilder builder(TestName()); - builder.Sub(builder.ConstantR0(2), builder.ConstantR0(5)); + Sub(ConstantR0(&builder, 2), ConstantR0(&builder, 5)); ComputeAndCompareR0(&builder, -3, {}); } XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) { XlaBuilder builder(TestName()); - auto a = builder.Parameter(0, ShapeUtil::MakeShape(S64, {}), "a"); - builder.ConvertElementType(a, F32); + auto a = Parameter(&builder, 0, ShapeUtil::MakeShape(S64, {}), "a"); + ConvertElementType(a, F32); int64 value = 3LL << 35; std::unique_ptr a_literal = Literal::CreateR0(value); @@ -171,9 +171,8 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) { XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32) { XlaBuilder builder(TestName()); - builder.Mul(builder.Mul(builder.ConstantR0(2.1f), - builder.ConstantR0(5.5f)), - builder.ConstantR0(0.5f)); + Mul(Mul(ConstantR0(&builder, 2.1f), ConstantR0(&builder, 5.5f)), + ConstantR0(&builder, 0.5f)); ComputeAndCompareR0(&builder, 5.775f, {}, error_spec_); } @@ -190,7 +189,7 @@ XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsS32) { for (int32 x : data) { for (int32 y : data) { XlaBuilder builder(TestName()); - builder.Mul(builder.ConstantR0(x), builder.ConstantR0(y)); + Mul(ConstantR0(&builder, x), ConstantR0(&builder, y)); // Signed integer overflow is undefined behavior in C++. Convert the input // integers to unsigned, perform the multiplication unsigned, and convert @@ -209,7 +208,7 @@ XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsU32) { for (uint32 x : data) { for (uint32 y : data) { XlaBuilder builder(TestName()); - builder.Mul(builder.ConstantR0(x), builder.ConstantR0(y)); + Mul(ConstantR0(&builder, x), ConstantR0(&builder, y)); uint32 expected = x * y; ComputeAndCompareR0(&builder, expected, {}); @@ -219,9 +218,8 @@ XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsU32) { XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) { XlaBuilder builder(TestName()); - builder.Mul( - builder.Mul(builder.ConstantR0(2), builder.ConstantR0(5)), - builder.ConstantR0(1)); + Mul(Mul(ConstantR0(&builder, 2), ConstantR0(&builder, 5)), + ConstantR0(&builder, 1)); ComputeAndCompareR0(&builder, 10, {}); } @@ -239,10 +237,10 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { std::unique_ptr c_data = client_->TransferToServer(*c_literal).ConsumeValueOrDie(); - XlaOp a = builder.Parameter(0, a_literal->shape(), "a"); - XlaOp b = builder.Parameter(1, b_literal->shape(), "b"); - XlaOp c = builder.Parameter(2, c_literal->shape(), "c"); - builder.Mul(builder.Mul(a, b), c); + XlaOp a = Parameter(&builder, 0, a_literal->shape(), "a"); + XlaOp b = Parameter(&builder, 1, b_literal->shape(), "b"); + XlaOp c = Parameter(&builder, 2, c_literal->shape(), "c"); + Mul(Mul(a, b), c); ComputeAndCompareR0(&builder, 5.775f, {a_data.get(), b_data.get(), c_data.get()}, @@ -251,14 +249,14 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsF32) { XlaBuilder builder(TestName()); - builder.Div(builder.ConstantR0(5.0f), builder.ConstantR0(2.5f)); + Div(ConstantR0(&builder, 5.0f), ConstantR0(&builder, 2.5f)); ComputeAndCompareR0(&builder, 2.0f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) { XlaBuilder builder(TestName()); - builder.Rem(builder.ConstantR0(2.5f), builder.ConstantR0(5.0f)); + Rem(ConstantR0(&builder, 2.5f), ConstantR0(&builder, 5.0f)); ComputeAndCompareR0(&builder, 2.5f, {}, error_spec_); } @@ -281,8 +279,8 @@ class DivS32Test : public ClientLibraryTestBase, XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) { DivS32Params p = GetParam(); XlaBuilder builder(TestName()); - builder.Div(builder.ConstantR0(p.dividend), - builder.ConstantR0(p.divisor)); + Div(ConstantR0(&builder, p.dividend), + ConstantR0(&builder, p.divisor)); ComputeAndCompareR0(&builder, p.quotient, {}); } @@ -290,8 +288,8 @@ XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) { XLA_TEST_P(DivS32Test, RemainderTwoScalarsS32) { DivS32Params p = GetParam(); XlaBuilder builder(TestName()); - builder.Rem(builder.ConstantR0(p.dividend), - builder.ConstantR0(p.divisor)); + Rem(ConstantR0(&builder, p.dividend), + ConstantR0(&builder, p.divisor)); ComputeAndCompareR0(&builder, p.remainder, {}); } @@ -305,7 +303,7 @@ XLA_TEST_P(DivS32Test, DivideTwoScalarsNonConstS32) { CreateR0Parameter(p.dividend, 0, "dividend", &builder, ÷nd); auto divisord = CreateR0Parameter(p.divisor, 1, "divisor", &builder, &divisor); - builder.Div(dividend, divisor); + Div(dividend, divisor); ComputeAndCompareR0(&builder, p.quotient, {dividendd.get(), divisord.get()}); @@ -320,7 +318,7 @@ XLA_TEST_P(DivS32Test, RemainderTwoScalarsNonConstDivisorS32) { CreateR0Parameter(p.dividend, 0, "dividend", &builder, ÷nd); auto divisord = CreateR0Parameter(p.divisor, 1, "divisor", &builder, &divisor); - builder.Rem(dividend, divisor); + Rem(dividend, divisor); ComputeAndCompareR0(&builder, p.remainder, {dividendd.get(), divisord.get()}); @@ -367,10 +365,10 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { XlaBuilder builder(TestName()); XlaOp dividend = - builder.Parameter(0, ShapeUtil::MakeShape(U32, {}), "dividend"); + Parameter(&builder, 0, ShapeUtil::MakeShape(U32, {}), "dividend"); XlaOp divisor = - builder.Parameter(1, ShapeUtil::MakeShape(U32, {}), "divisor"); - builder.Div(dividend, divisor); + Parameter(&builder, 1, ShapeUtil::MakeShape(U32, {}), "divisor"); + Div(dividend, divisor); TF_ASSERT_OK_AND_ASSIGN(div_computation, builder.Build()); } @@ -408,10 +406,10 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { XlaBuilder builder(TestName()); XlaOp dividend = - builder.Parameter(0, ShapeUtil::MakeShape(U32, {}), "dividend"); + Parameter(&builder, 0, ShapeUtil::MakeShape(U32, {}), "dividend"); XlaOp divisor = - builder.Parameter(1, ShapeUtil::MakeShape(U32, {}), "divisor"); - builder.Rem(dividend, divisor); + Parameter(&builder, 1, ShapeUtil::MakeShape(U32, {}), "divisor"); + Rem(dividend, divisor); TF_ASSERT_OK_AND_ASSIGN(rem_computation, builder.Build()); } @@ -439,8 +437,8 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x"); - builder.Rem(x, builder.ConstantR0(80000)); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x"); + Rem(x, ConstantR0(&builder, 80000)); std::unique_ptr literal = Literal::CreateR0(87919); TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*literal)); @@ -451,15 +449,15 @@ XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsU32) { XlaBuilder builder(TestName()); // This verifies 0xFFFFFFFE / 2 = 0x7FFFFFFF. If XLA incorrectly treated U32 // as S32, it would output -2 / 2 = -1 (0xFFFFFFFF). - builder.Div(builder.ConstantR0(0xFFFFFFFE), - builder.ConstantR0(2)); + Div(ConstantR0(&builder, 0xFFFFFFFE), + ConstantR0(&builder, 2)); ComputeAndCompareR0(&builder, 0x7FFFFFFF, {}); } XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsU32) { XlaBuilder builder(TestName()); - builder.Rem(builder.ConstantR0(11), builder.ConstantR0(3)); + Rem(ConstantR0(&builder, 11), ConstantR0(&builder, 3)); ComputeAndCompareR0(&builder, 2, {}); } @@ -468,7 +466,7 @@ XLA_TEST_F(ScalarComputationsTest, AndBool) { for (bool x : {false, true}) { for (bool y : {false, true}) { XlaBuilder builder(TestName()); - builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); + And(ConstantR0(&builder, x), ConstantR0(&builder, y)); ComputeAndCompareR0(&builder, x && y, {}); } @@ -479,7 +477,7 @@ XLA_TEST_F(ScalarComputationsTest, AndS32) { for (int32 x : {0, 8}) { for (int32 y : {1, -16}) { XlaBuilder builder(TestName()); - builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); + And(ConstantR0(&builder, x), ConstantR0(&builder, y)); ComputeAndCompareR0(&builder, x & y, {}); } @@ -490,7 +488,7 @@ XLA_TEST_F(ScalarComputationsTest, AndU32) { for (uint32 x : {0, 8}) { for (uint32 y : {1, 16}) { XlaBuilder builder(TestName()); - builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); + And(ConstantR0(&builder, x), ConstantR0(&builder, y)); ComputeAndCompareR0(&builder, x & y, {}); } @@ -501,7 +499,7 @@ XLA_TEST_F(ScalarComputationsTest, OrBool) { for (bool x : {false, true}) { for (bool y : {false, true}) { XlaBuilder builder(TestName()); - builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); + Or(ConstantR0(&builder, x), ConstantR0(&builder, y)); ComputeAndCompareR0(&builder, x || y, {}); } @@ -512,7 +510,7 @@ XLA_TEST_F(ScalarComputationsTest, OrS32) { for (int32 x : {0, 8}) { for (int32 y : {1, -16}) { XlaBuilder builder(TestName()); - builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); + Or(ConstantR0(&builder, x), ConstantR0(&builder, y)); ComputeAndCompareR0(&builder, x | y, {}); } @@ -523,7 +521,7 @@ XLA_TEST_F(ScalarComputationsTest, OrU32) { for (uint32 x : {0, 8}) { for (uint32 y : {1, 16}) { XlaBuilder builder(TestName()); - builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); + Or(ConstantR0(&builder, x), ConstantR0(&builder, y)); ComputeAndCompareR0(&builder, x | y, {}); } @@ -533,7 +531,7 @@ XLA_TEST_F(ScalarComputationsTest, OrU32) { XLA_TEST_F(ScalarComputationsTest, NotBool) { for (bool x : {false, true}) { XlaBuilder builder(TestName()); - builder.Not(builder.ConstantR0(x)); + Not(ConstantR0(&builder, x)); ComputeAndCompareR0(&builder, !x, {}); } @@ -542,7 +540,7 @@ XLA_TEST_F(ScalarComputationsTest, NotBool) { XLA_TEST_F(ScalarComputationsTest, NotS32) { for (int32 x : {-1, 0, 1}) { XlaBuilder builder(TestName()); - builder.Not(builder.ConstantR0(x)); + Not(ConstantR0(&builder, x)); ComputeAndCompareR0(&builder, ~x, {}); } @@ -551,7 +549,7 @@ XLA_TEST_F(ScalarComputationsTest, NotS32) { XLA_TEST_F(ScalarComputationsTest, NotU32) { for (uint32 x : {0, 1, 2}) { XlaBuilder builder(TestName()); - builder.Not(builder.ConstantR0(x)); + Not(ConstantR0(&builder, x)); ComputeAndCompareR0(&builder, ~x, {}); } @@ -559,18 +557,18 @@ XLA_TEST_F(ScalarComputationsTest, NotU32) { XLA_TEST_F(ScalarComputationsTest, SelectScalarTrue) { XlaBuilder builder(TestName()); - builder.Select(builder.ConstantR0(true), // The predicate. - builder.ConstantR0(123.0f), // The value on true. - builder.ConstantR0(42.0f)); // The value on false. + Select(ConstantR0(&builder, true), // The predicate. + ConstantR0(&builder, 123.0f), // The value on true. + ConstantR0(&builder, 42.0f)); // The value on false. ComputeAndCompareR0(&builder, 123.0f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, SelectScalarFalse) { XlaBuilder builder(TestName()); - builder.Select(builder.ConstantR0(false), // The predicate. - builder.ConstantR0(123.0f), // The value on true. - builder.ConstantR0(42.0f)); // The value on false. + Select(ConstantR0(&builder, false), // The predicate. + ConstantR0(&builder, 123.0f), // The value on true. + ConstantR0(&builder, 42.0f)); // The value on false. ComputeAndCompareR0(&builder, 42.0f, {}, error_spec_); } @@ -579,313 +577,311 @@ XLA_TEST_F(ScalarComputationsTest, SelectScalarFalse) { // templatized comparison tests. XLA_TEST_F(ScalarComputationsTest, CompareGtScalar) { XlaBuilder builder(TestName()); - builder.Gt(builder.ConstantR0(2.0f), builder.ConstantR0(1.0f)); + Gt(ConstantR0(&builder, 2.0f), ConstantR0(&builder, 1.0f)); ComputeAndCompareR0(&builder, true, {}); } // S32 comparisons. XLA_TEST_F(ScalarComputationsTest, CompareEqS32Greater) { - TestCompare(2, 1, false, &XlaBuilder::Eq); + TestCompare(2, 1, false, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareEqS32Equal) { - TestCompare(3, 3, true, &XlaBuilder::Eq); + TestCompare(3, 3, true, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeS32) { - TestCompare(2, 1, true, &XlaBuilder::Ne); + TestCompare(2, 1, true, &Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeS32) { - TestCompare(2, 1, true, &XlaBuilder::Ge); + TestCompare(2, 1, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtS32) { - TestCompare(1, 5, false, &XlaBuilder::Gt); + TestCompare(1, 5, false, &Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeS32) { - TestCompare(2, 1, false, &XlaBuilder::Le); + TestCompare(2, 1, false, &Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtS32) { - TestCompare(9, 7, false, &XlaBuilder::Lt); + TestCompare(9, 7, false, &Lt); TestCompare(std::numeric_limits::min(), - std::numeric_limits::max(), true, &XlaBuilder::Lt); + std::numeric_limits::max(), true, &Lt); } // U32 comparisons. XLA_TEST_F(ScalarComputationsTest, CompareEqU32False) { - TestCompare(2, 1, false, &XlaBuilder::Eq); + TestCompare(2, 1, false, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeU32) { - TestCompare(2, 1, true, &XlaBuilder::Ne); + TestCompare(2, 1, true, &Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeU32Greater) { - TestCompare(2, 1, true, &XlaBuilder::Ge); + TestCompare(2, 1, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeU32Equal) { - TestCompare(3, 3, true, &XlaBuilder::Ge); + TestCompare(3, 3, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtU32) { - TestCompare(1, 5, false, &XlaBuilder::Gt); - TestCompare(5, 5, false, &XlaBuilder::Gt); - TestCompare(5, 1, true, &XlaBuilder::Gt); + TestCompare(1, 5, false, &Gt); + TestCompare(5, 5, false, &Gt); + TestCompare(5, 1, true, &Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeU32) { - TestCompare(2, 1, false, &XlaBuilder::Le); + TestCompare(2, 1, false, &Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtU32) { - TestCompare(9, 7, false, &XlaBuilder::Lt); - TestCompare(0, std::numeric_limits::max(), true, - &XlaBuilder::Lt); + TestCompare(9, 7, false, &Lt); + TestCompare(0, std::numeric_limits::max(), true, &Lt); } // F32 comparisons. XLA_TEST_F(ScalarComputationsTest, CompareEqF32False) { - TestCompare(2.0, 1.3, false, &XlaBuilder::Eq); + TestCompare(2.0, 1.3, false, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeF32) { - TestCompare(2.0, 1.3, true, &XlaBuilder::Ne); + TestCompare(2.0, 1.3, true, &Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32Greater) { - TestCompare(2.0, 1.9, true, &XlaBuilder::Ge); + TestCompare(2.0, 1.9, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32Equal) { - TestCompare(3.5, 3.5, true, &XlaBuilder::Ge); + TestCompare(3.5, 3.5, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtF32) { - TestCompare(1.0, 5.2, false, &XlaBuilder::Gt); + TestCompare(1.0, 5.2, false, &Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeF32) { - TestCompare(2.0, 1.2, false, &XlaBuilder::Le); + TestCompare(2.0, 1.2, false, &Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32) { - TestCompare(9.0, 7.2, false, &XlaBuilder::Lt); + TestCompare(9.0, 7.2, false, &Lt); } // F32 comparisons with exceptional values. The test names encode the // left/right operands at the end, and use Minf and Mzero for -inf and -0.0. XLA_TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) { - TestCompare(-INFINITY, -0.0, true, &XlaBuilder::Lt); + TestCompare(-INFINITY, -0.0, true, &Lt); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) { // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754. - TestCompare(-0.0, 0.0, false, &XlaBuilder::Lt); + TestCompare(-0.0, 0.0, false, &Lt); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) { - TestCompare(0.0, INFINITY, true, &XlaBuilder::Lt); + TestCompare(0.0, INFINITY, true, &Lt); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) { - TestCompare(-INFINITY, -0.0, false, &XlaBuilder::Ge); + TestCompare(-INFINITY, -0.0, false, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) { // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754. - TestCompare(-0.0, 0.0, true, &XlaBuilder::Ge); + TestCompare(-0.0, 0.0, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) { - TestCompare(0.0, INFINITY, false, &XlaBuilder::Ge); + TestCompare(0.0, INFINITY, false, &Ge); } XLA_TEST_F(ScalarComputationsTest, ExpScalar) { XlaBuilder builder(TestName()); - builder.Exp(builder.ConstantR0(2.0f)); + Exp(ConstantR0(&builder, 2.0f)); ComputeAndCompareR0(&builder, 7.3890562, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, LogScalar) { XlaBuilder builder("log"); - builder.Log(builder.ConstantR0(2.0f)); + Log(ConstantR0(&builder, 2.0f)); ComputeAndCompareR0(&builder, 0.6931471, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, TanhScalar) { XlaBuilder builder(TestName()); - builder.Tanh(builder.ConstantR0(2.0f)); + Tanh(ConstantR0(&builder, 2.0f)); ComputeAndCompareR0(&builder, 0.96402758, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, TanhDoubleScalar) { XlaBuilder builder(TestName()); - builder.Tanh(builder.ConstantR0(2.0)); + Tanh(ConstantR0(&builder, 2.0)); ComputeAndCompareR0(&builder, 0.96402758, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, PowScalar) { XlaBuilder builder(TestName()); - builder.Pow(builder.ConstantR0(2.0f), builder.ConstantR0(3.0f)); + Pow(ConstantR0(&builder, 2.0f), ConstantR0(&builder, 3.0f)); ComputeAndCompareR0(&builder, 8.0, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, ClampScalarHighS32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(-1), // The lower bound. - builder.ConstantR0(5), // The operand to be clamped. - builder.ConstantR0(3)); // The upper bound. + Clamp(ConstantR0(&builder, -1), // The lower bound. + ConstantR0(&builder, 5), // The operand to be clamped. + ConstantR0(&builder, 3)); // The upper bound. ComputeAndCompareR0(&builder, 3, {}); } XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleS32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(-1), // The lower bound. - builder.ConstantR0(2), // The operand to be clamped. - builder.ConstantR0(3)); // The upper bound. + Clamp(ConstantR0(&builder, -1), // The lower bound. + ConstantR0(&builder, 2), // The operand to be clamped. + ConstantR0(&builder, 3)); // The upper bound. ComputeAndCompareR0(&builder, 2, {}); } XLA_TEST_F(ScalarComputationsTest, ClampScalarLowS32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(-1), // The lower bound. - builder.ConstantR0(-5), // The operand to be clamped. - builder.ConstantR0(3)); // The upper bound. + Clamp(ConstantR0(&builder, -1), // The lower bound. + ConstantR0(&builder, -5), // The operand to be clamped. + ConstantR0(&builder, 3)); // The upper bound. ComputeAndCompareR0(&builder, -1, {}); } XLA_TEST_F(ScalarComputationsTest, ClampScalarHighU32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(1), // The lower bound. - builder.ConstantR0(5), // The operand to be clamped. - builder.ConstantR0(3)); // The upper bound. + Clamp(ConstantR0(&builder, 1), // The lower bound. + ConstantR0(&builder, 5), // The operand to be clamped. + ConstantR0(&builder, 3)); // The upper bound. ComputeAndCompareR0(&builder, 3, {}); } XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleU32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(1), // The lower bound. - builder.ConstantR0(2), // The operand to be clamped. - builder.ConstantR0(3)); // The upper bound. + Clamp(ConstantR0(&builder, 1), // The lower bound. + ConstantR0(&builder, 2), // The operand to be clamped. + ConstantR0(&builder, 3)); // The upper bound. ComputeAndCompareR0(&builder, 2, {}); } XLA_TEST_F(ScalarComputationsTest, ClampScalarLowU32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(1), // The lower bound. - builder.ConstantR0(0), // The operand to be clamped. - builder.ConstantR0(3)); // The upper bound. + Clamp(ConstantR0(&builder, 1), // The lower bound. + ConstantR0(&builder, 0), // The operand to be clamped. + ConstantR0(&builder, 3)); // The upper bound. ComputeAndCompareR0(&builder, 1, {}); } XLA_TEST_F(ScalarComputationsTest, ClampScalarHighF32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. - builder.ConstantR0(5.0f), // The operand to be clamped. - builder.ConstantR0(3.0f)); // The upper bound. + Clamp(ConstantR0(&builder, 2.0f), // The lower bound. + ConstantR0(&builder, 5.0f), // The operand to be clamped. + ConstantR0(&builder, 3.0f)); // The upper bound. ComputeAndCompareR0(&builder, 3.0, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleF32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. - builder.ConstantR0(2.5f), // The operand to be clamped. - builder.ConstantR0(3.0f)); // The upper bound. + Clamp(ConstantR0(&builder, 2.0f), // The lower bound. + ConstantR0(&builder, 2.5f), // The operand to be clamped. + ConstantR0(&builder, 3.0f)); // The upper bound. ComputeAndCompareR0(&builder, 2.5, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. - builder.ConstantR0(-5.0f), // The operand to be clamped. - builder.ConstantR0(3.0f)); // The upper bound. + Clamp(ConstantR0(&builder, 2.0f), // The lower bound. + ConstantR0(&builder, -5.0f), // The operand to be clamped. + ConstantR0(&builder, 3.0f)); // The upper bound. ComputeAndCompareR0(&builder, 2.0, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, MinS32Above) { - TestMinMax(10, 3, 3, &XlaBuilder::Min); + TestMinMax(10, 3, 3, &Min); } XLA_TEST_F(ScalarComputationsTest, MinS32Below) { - TestMinMax(-100, 3, -100, &XlaBuilder::Min); + TestMinMax(-100, 3, -100, &Min); } XLA_TEST_F(ScalarComputationsTest, MaxS32Above) { - TestMinMax(10, 3, 10, &XlaBuilder::Max); + TestMinMax(10, 3, 10, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxS32Below) { - TestMinMax(-100, 3, 3, &XlaBuilder::Max); + TestMinMax(-100, 3, 3, &Max); } XLA_TEST_F(ScalarComputationsTest, MinU32Above) { const uint32 large = std::numeric_limits::max(); - TestMinMax(large, 3, 3, &XlaBuilder::Min); + TestMinMax(large, 3, 3, &Min); } XLA_TEST_F(ScalarComputationsTest, MinU32Below) { - TestMinMax(0, 5, 0, &XlaBuilder::Min); + TestMinMax(0, 5, 0, &Min); } XLA_TEST_F(ScalarComputationsTest, MaxU32Above) { const uint32 large = std::numeric_limits::max(); - TestMinMax(large, 3, large, &XlaBuilder::Max); + TestMinMax(large, 3, large, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxU32Below) { - TestMinMax(0, 5, 5, &XlaBuilder::Max); + TestMinMax(0, 5, 5, &Max); } XLA_TEST_F(ScalarComputationsTest, MinF32Above) { - TestMinMax(10.1f, 3.1f, 3.1f, &XlaBuilder::Min); + TestMinMax(10.1f, 3.1f, 3.1f, &Min); } XLA_TEST_F(ScalarComputationsTest, MinF32Below) { - TestMinMax(-100.1f, 3.1f, -100.1f, &XlaBuilder::Min); + TestMinMax(-100.1f, 3.1f, -100.1f, &Min); } XLA_TEST_F(ScalarComputationsTest, MinPropagatesNan) { SetFastMathDisabled(true); - TestMinMax(NAN, 3.1f, NAN, &XlaBuilder::Min); - TestMinMax(-3.1f, NAN, NAN, &XlaBuilder::Min); + TestMinMax(NAN, 3.1f, NAN, &Min); + TestMinMax(-3.1f, NAN, NAN, &Min); } XLA_TEST_F(ScalarComputationsTest, MaxF32Above) { - TestMinMax(10.1f, 3.1f, 10.1f, &XlaBuilder::Max); + TestMinMax(10.1f, 3.1f, 10.1f, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxF32Below) { - TestMinMax(-100.1f, 3.1f, 3.1f, &XlaBuilder::Max); + TestMinMax(-100.1f, 3.1f, 3.1f, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxPropagatesNan) { SetFastMathDisabled(true); - TestMinMax(NAN, 3.1f, NAN, &XlaBuilder::Max); - TestMinMax(-3.1f, NAN, NAN, &XlaBuilder::Max); + TestMinMax(NAN, 3.1f, NAN, &Max); + TestMinMax(-3.1f, NAN, NAN, &Max); } XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) { // Compute the expression (1 * (3 - 1) * (7 + 0) - 4) / 20. XlaBuilder b(TestName()); - b.Div( - b.Sub(b.Mul(b.ConstantR0(1), - b.Mul(b.Sub(b.ConstantR0(3), b.ConstantR0(1)), - b.Add(b.ConstantR0(7), b.ConstantR0(0)))), - b.ConstantR0(4)), - b.ConstantR0(20)); + Div(Sub(Mul(ConstantR0(&b, 1), + Mul(Sub(ConstantR0(&b, 3), ConstantR0(&b, 1)), + Add(ConstantR0(&b, 7), ConstantR0(&b, 0)))), + ConstantR0(&b, 4)), + ConstantR0(&b, 20)); ComputeAndCompareR0(&b, 0.5, {}, error_spec_); } @@ -893,10 +889,10 @@ XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) { XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) { // Compute the expression 1 * (3 - 1) * (7 + 0) - 4. XlaBuilder b(TestName()); - b.Sub(b.Mul(b.ConstantR0(1), - b.Mul(b.Sub(b.ConstantR0(3), b.ConstantR0(1)), - b.Add(b.ConstantR0(7), b.ConstantR0(0)))), - b.ConstantR0(4)); + Sub(Mul(ConstantR0(&b, 1), + Mul(Sub(ConstantR0(&b, 3), ConstantR0(&b, 1)), + Add(ConstantR0(&b, 7), ConstantR0(&b, 0)))), + ConstantR0(&b, 4)); ComputeAndCompareR0(&b, 10, {}); } @@ -908,15 +904,15 @@ XLA_TEST_F(ScalarComputationsTest, SqrtF320) { std::unique_ptr zero_data = client_->TransferToServer(zero_literal).ConsumeValueOrDie(); - XlaOp zero = builder.Parameter(0, zero_literal.shape(), "zero"); - builder.SqrtF32(zero); + XlaOp zero = Parameter(&builder, 0, zero_literal.shape(), "zero"); + SqrtF32(zero); ComputeAndCompareR0(&builder, 0.0f, {zero_data.get()}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, RoundScalar) { XlaBuilder builder(TestName()); - builder.Round(builder.ConstantR0(1.4f)); + Round(ConstantR0(&builder, 1.4f)); ComputeAndCompareR0(&builder, 1.0f, {}, error_spec_); } diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index 7015e5a6a31f506d30c2629d7735482cf354455a..0a173fbbbd5cb5e5005728331561008b8b29af26 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -73,16 +73,16 @@ XLA_TEST_P(SelectAndScatterTest, ParamTest) { auto operand_shape = GetParam().operand_shape; Array o(operand_shape); o.FillRandom(1.5f); - auto operand = builder_.ConstantFromArray(o); + auto operand = ConstantFromArray(&builder_, o); auto source_shape = GetParam().source_shape; Array s(source_shape); s.FillRandom(12.0f); - auto source = builder_.ConstantFromArray(s); + auto source = ConstantFromArray(&builder_, s); - builder_.SelectAndScatter(operand, ge_f32_, GetParam().window_dimensions, - GetParam().window_strides, GetParam().padding_type, - source, builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, GetParam().window_dimensions, + GetParam().window_strides, GetParam().padding_type, source, + ConstantR0(&builder_, 0.0f), add_f32_); ComputeAndCompare(&builder_, {}, ErrorSpec(1e-5)); } @@ -197,110 +197,110 @@ INSTANTIATE_TEST_CASE_P( // Test for F32 1D array, with a zero-element input. XLA_TEST_F(SelectAndScatterTest, R1S0F32) { - const auto operand = builder_.ConstantR1({}); - const auto source = builder_.ConstantR1({}); - builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3}, - /*window_strides=*/{3}, Padding::kValid, source, - builder_.ConstantR0(0.0f), add_f32_); + const auto operand = ConstantR1(&builder_, {}); + const auto source = ConstantR1(&builder_, {}); + SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3}, + /*window_strides=*/{3}, Padding::kValid, source, + ConstantR0(&builder_, 0.0f), add_f32_); ComputeAndCompareR1(&builder_, {}, {}, ErrorSpec(1e-7)); } // Test for F32 1D array, when windows do not overlap. XLA_TEST_F(SelectAndScatterTest, R1F32) { const auto operand = - builder_.ConstantR1({1.f, 9.f, 3.f, 7.f, 5.f, 6.f}); - const auto source = builder_.ConstantR1({34.f, 42.f}); + ConstantR1(&builder_, {1.f, 9.f, 3.f, 7.f, 5.f, 6.f}); + const auto source = ConstantR1(&builder_, {34.f, 42.f}); const std::vector expected = {0.f, 34.f, 0.f, 42.f, 0.f, 0.f}; - builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3}, - /*window_strides=*/{3}, Padding::kValid, source, - builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3}, + /*window_strides=*/{3}, Padding::kValid, source, + ConstantR0(&builder_, 0.0f), add_f32_); ComputeAndCompareR1(&builder_, expected, {}, ErrorSpec(1e-7)); } // Test for S32 1D array, when windows do not overlap and the init value is 1. XLA_TEST_F(SelectAndScatterTest, R1S32) { - const auto operand = builder_.ConstantR1({-1, 0, 6, 4, -4, 10}); - const auto source = builder_.ConstantR1({-10, 20}); + const auto operand = ConstantR1(&builder_, {-1, 0, 6, 4, -4, 10}); + const auto source = ConstantR1(&builder_, {-10, 20}); const std::vector expected = {1, 1, -9, 1, 1, 21}; - builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3}, - /*window_strides=*/{3}, Padding::kValid, source, - builder_.ConstantR0(1), add_s32_); + SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3}, + /*window_strides=*/{3}, Padding::kValid, source, + ConstantR0(&builder_, 1), add_s32_); ComputeAndCompareR1(&builder_, expected, {}); } // Test for S32 1D array, when windows overlap with each other. XLA_TEST_F(SelectAndScatterTest, R1S32OverlappingWindow) { - const auto operand = builder_.ConstantR1({1, 9, 3, 7, 5, 6}); - const auto source = builder_.ConstantR1({34, 42, 53, 19}); + const auto operand = ConstantR1(&builder_, {1, 9, 3, 7, 5, 6}); + const auto source = ConstantR1(&builder_, {34, 42, 53, 19}); const std::vector expected = {0, 76, 0, 72, 0, 0}; - builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3}, - /*window_strides=*/{1}, Padding::kValid, source, - builder_.ConstantR0(0), add_s32_); + SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3}, + /*window_strides=*/{1}, Padding::kValid, source, + ConstantR0(&builder_, 0), add_s32_); ComputeAndCompareR1(&builder_, expected, {}); } // Test for S32 2D array, when windows do not overlap. XLA_TEST_F(SelectAndScatterTest, R2S32) { const auto operand = - builder_.ConstantR2({{7, 2, 5, 3, 10, 2}, {3, 8, 9, 3, 4, 2}}); - const auto source = builder_.ConstantR2({{2, 6}}); + ConstantR2(&builder_, {{7, 2, 5, 3, 10, 2}, {3, 8, 9, 3, 4, 2}}); + const auto source = ConstantR2(&builder_, {{2, 6}}); Array2D expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}}); - builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3}, - /*window_strides=*/{2, 3}, Padding::kValid, source, - builder_.ConstantR0(0), add_s32_); + SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3}, + /*window_strides=*/{2, 3}, Padding::kValid, source, + ConstantR0(&builder_, 0), add_s32_); ComputeAndCompareR2(&builder_, expected, {}); } // Test for tie breaking rule in ge_f32_. When a tie is present, the operand // that has the lower lexicographical order (smaller index) should be chosen. XLA_TEST_F(SelectAndScatterTest, R2F32Tie) { - const auto operand = builder_.ConstantR2( - {{0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}}); - const auto source = builder_.ConstantR2( - {{1.0f, 2.0f, 3.0f}, {4.f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}); + const auto operand = ConstantR2( + &builder_, {{0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}}); + const auto source = ConstantR2( + &builder_, {{1.0f, 2.0f, 3.0f}, {4.f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}); Array2D expected( {{12.f, 9.f, 0.f}, {15.f, 9.f, 0.f}, {0.f, 0.f, 0.f}}); - builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3, 3}, - /*window_strides=*/{1, 1}, Padding::kSame, source, - builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3, 3}, + /*window_strides=*/{1, 1}, Padding::kSame, source, + ConstantR0(&builder_, 0.0f), add_f32_); ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(1e-7)); } // Similar to SelectAndScatterTest.R2S32 but the input is transposed. XLA_TEST_F(SelectAndScatterTest, ReshapeR2S32) { - const auto operand = builder_.ConstantR2( - {{7, 3}, {2, 8}, {5, 9}, {3, 3}, {10, 4}, {2, 2}}); + const auto operand = ConstantR2( + &builder_, {{7, 3}, {2, 8}, {5, 9}, {3, 3}, {10, 4}, {2, 2}}); const auto reshape = - builder_.Reshape(operand, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6}); - const auto source = builder_.ConstantR2({{2, 6}}); + Reshape(operand, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6}); + const auto source = ConstantR2(&builder_, {{2, 6}}); Array2D expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}}); - builder_.SelectAndScatter(reshape, ge_s32_, /*window_dimensions=*/{2, 3}, - /*window_strides=*/{2, 3}, Padding::kValid, source, - builder_.ConstantR0(0), add_s32_); + SelectAndScatter(reshape, ge_s32_, /*window_dimensions=*/{2, 3}, + /*window_strides=*/{2, 3}, Padding::kValid, source, + ConstantR0(&builder_, 0), add_s32_); ComputeAndCompareR2(&builder_, expected, {}); } // Test for S32 2D array, when windows overlap with each other. XLA_TEST_F(SelectAndScatterTest, R2S32OverlappingWindow) { const auto operand = - builder_.ConstantR2({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); - const auto source = builder_.ConstantR2({{2, 6, 4}}); + ConstantR2(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); + const auto source = ConstantR2(&builder_, {{2, 6, 4}}); Array2D expected({{0, 0, 0, 0, 0}, {0, 0, 12, 0, 0}}); - builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3}, - /*window_strides=*/{1, 1}, Padding::kValid, source, - builder_.ConstantR0(0), add_s32_); + SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3}, + /*window_strides=*/{1, 1}, Padding::kValid, source, + ConstantR0(&builder_, 0), add_s32_); ComputeAndCompareR2(&builder_, expected, {}); } // Test for S32 2D array, when the padding is Padding::kSAME. XLA_TEST_F(SelectAndScatterTest, R2S32SamePadding) { const auto operand = - builder_.ConstantR2({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); - const auto source = builder_.ConstantR2({{2, 6, 4}}); + ConstantR2(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); + const auto source = ConstantR2(&builder_, {{2, 6, 4}}); Array2D expected({{0, 0, 0, 0, 4}, {0, 2, 6, 0, 0}}); - builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2}, - /*window_strides=*/{2, 2}, Padding::kSame, source, - builder_.ConstantR0(0), add_s32_); + SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2}, + /*window_strides=*/{2, 2}, Padding::kSame, source, + ConstantR0(&builder_, 0), add_s32_); ComputeAndCompareR2(&builder_, expected, {}); } @@ -308,25 +308,26 @@ XLA_TEST_F(SelectAndScatterTest, R2S32SamePadding) { // with each other. XLA_TEST_F(SelectAndScatterTest, R2S32SamePaddingOverlappingWindow) { const auto operand = - builder_.ConstantR2({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); + ConstantR2(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); const auto source = - builder_.ConstantR2({{2, 6, 4, 7, 1}, {3, 5, 8, 9, 10}}); + ConstantR2(&builder_, {{2, 6, 4, 7, 1}, {3, 5, 8, 9, 10}}); Array2D expected({{0, 0, 0, 0, 8}, {0, 5, 23, 0, 19}}); - builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2}, - /*window_strides=*/{1, 1}, Padding::kSame, source, - builder_.ConstantR0(0), add_s32_); + SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2}, + /*window_strides=*/{1, 1}, Padding::kSame, source, + ConstantR0(&builder_, 0), add_s32_); ComputeAndCompareR2(&builder_, expected, {}); } XLA_TEST_F(SelectAndScatterTest, R2F32OverlappingR2Source) { - const auto operand = builder_.ConstantR2( - {{1.5f, 2.5f, 1.5f}, {3.5f, 1.5f, 3.5f}, {4.5f, 2.5f, 4.5f}}); - const auto source = builder_.ConstantR2({{1.0f, 2.0f}, {3.0f, 4.0f}}); + const auto operand = ConstantR2( + &builder_, {{1.5f, 2.5f, 1.5f}, {3.5f, 1.5f, 3.5f}, {4.5f, 2.5f, 4.5f}}); + const auto source = + ConstantR2(&builder_, {{1.0f, 2.0f}, {3.0f, 4.0f}}); Array2D expected( {{0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 2.0f}, {3.0f, 0.0f, 4.0f}}); - builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{2, 2}, - /*window_strides=*/{1, 1}, Padding::kValid, source, - builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{2, 2}, + /*window_strides=*/{1, 1}, Padding::kValid, source, + ConstantR0(&builder_, 0.0f), add_f32_); ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(1e-7)); } @@ -342,16 +343,16 @@ TEST_F(SelectAndScatterTest, R4F32Valid) { {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f}}; Array4D o(4, 6, 15, 220); o.FillWithPZ(pzo); - auto operand = builder_.ConstantR4FromArray4D(o); + auto operand = ConstantR4FromArray4D(&builder_, o); Array4D e(4, 6, 15, 220); e.FillWithPZ(pze); Array4D s(2, 2, 15, 220); s.FillWithPZ(pzs); - auto source = builder_.ConstantR4FromArray4D(s); + auto source = ConstantR4FromArray4D(&builder_, s); s.FillWithPZ(pzs); - builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1}, - Padding::kValid, source, - builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1}, + Padding::kValid, source, ConstantR0(&builder_, 0.0f), + add_f32_); ComputeAndCompareR4(&builder_, e, {}, ErrorSpec(1e-7)); } @@ -367,16 +368,16 @@ TEST_F(SelectAndScatterTest, R4F32Overlap) { {0.0f, 0.0f, 0.0f, 1.0f, 0.0f}}; Array4D o(4, 5, 17, 128); o.FillWithPZ(pzo); - auto operand = builder_.ConstantR4FromArray4D(o); + auto operand = ConstantR4FromArray4D(&builder_, o); Array4D e(4, 5, 17, 128); e.FillWithPZ(pze); Array4D s(2, 2, 17, 128); s.FillWithPZ(pzs); - auto source = builder_.ConstantR4FromArray4D(s); + auto source = ConstantR4FromArray4D(&builder_, s); s.FillWithPZ(pzs); - builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1}, - Padding::kValid, source, - builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1}, + Padding::kValid, source, ConstantR0(&builder_, 0.0f), + add_f32_); ComputeAndCompareR4(&builder_, e, {}, ErrorSpec(1e-7)); } @@ -392,16 +393,16 @@ TEST_F(SelectAndScatterTest, R4F32OverlapSmall) { {0.0f, 0.0f, 0.0f, 1.0f, 0.0f}}; Array4D o(4, 5, 1, 1); o.FillWithPZ(pzo); - auto operand = builder_.ConstantR4FromArray4D(o); + auto operand = ConstantR4FromArray4D(&builder_, o); Array4D e(4, 5, 1, 1); e.FillWithPZ(pze); Array4D s(2, 2, 1, 1); s.FillWithPZ(pzs); - auto source = builder_.ConstantR4FromArray4D(s); + auto source = ConstantR4FromArray4D(&builder_, s); s.FillWithPZ(pzs); - builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1}, - Padding::kValid, source, - builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1}, + Padding::kValid, source, ConstantR0(&builder_, 0.0f), + add_f32_); ComputeAndCompareR4(&builder_, e, {}, ErrorSpec(1e-7)); } @@ -414,39 +415,39 @@ TEST_F(SelectAndScatterTest, R4F32RefValidFixedSmall) { Array2D pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}}; Array4D o(4, 6, 4, 4); o.FillWithPZ(pzo); - auto operand = builder_.ConstantR4FromArray4D(o); + auto operand = ConstantR4FromArray4D(&builder_, o); Array4D s(2, 2, 4, 4); s.FillWithPZ(pzs); - auto source = builder_.ConstantR4FromArray4D(s); + auto source = ConstantR4FromArray4D(&builder_, s); s.FillWithPZ(pzs); - builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1}, - Padding::kValid, source, - builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1}, + Padding::kValid, source, ConstantR0(&builder_, 0.0f), + add_f32_); auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {2, 3, 1, 1}, {2, 3, 1, 1}, false); ComputeAndCompareR4(&builder_, *e, {}, ErrorSpec(1e-7)); } XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMaxScatter) { - const auto operand = builder_.ConstantR1({1, 2, 3, 100, 3, 2, 1}); - const auto source = builder_.ConstantR1({34, 42, 53, 19}); + const auto operand = ConstantR1(&builder_, {1, 2, 3, 100, 3, 2, 1}); + const auto source = ConstantR1(&builder_, {34, 42, 53, 19}); const std::vector expected = {0, 0, 0, 53, 0, 0, 0}; - builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4}, - /*window_strides=*/{1}, Padding::kValid, source, - builder_.ConstantR0(0), max_f32_); + SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4}, + /*window_strides=*/{1}, Padding::kValid, source, + ConstantR0(&builder_, 0), max_f32_); ComputeAndCompareR1(&builder_, expected, {}, ErrorSpec(1e-7)); } XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMinScatter) { - const auto operand = builder_.ConstantR1({1, 2, 3, 100, 3, 2, 1}); - const auto source = builder_.ConstantR1({34, 42, 53, 19}); + const auto operand = ConstantR1(&builder_, {1, 2, 3, 100, 3, 2, 1}); + const auto source = ConstantR1(&builder_, {34, 42, 53, 19}); const float max_float = std::numeric_limits::max(); const std::vector expected = {max_float, max_float, max_float, 19, max_float, max_float, max_float}; - builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4}, - /*window_strides=*/{1}, Padding::kValid, source, - builder_.ConstantR0(max_float), min_f32_); + SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4}, + /*window_strides=*/{1}, Padding::kValid, source, + ConstantR0(&builder_, max_float), min_f32_); ComputeAndCompareR1(&builder_, expected, {}, ErrorSpec(1e-7)); } diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc index 72707f224446c7585d1d90ac6681a7b38c41d5f1..59409ab26e1c19a8271318c18e19caa7b8ddc3b7 100644 --- a/tensorflow/compiler/xla/tests/select_test.cc +++ b/tensorflow/compiler/xla/tests/select_test.cc @@ -35,50 +35,52 @@ class SelectTest : public ClientLibraryTestBase { TEST_F(SelectTest, SelectScalarF32True) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto on_true = builder.ConstantR0(123.0f); - auto on_false = builder.ConstantR0(42.0f); - auto result = builder.Select(pred, on_true, on_false); + auto pred = ConstantR0(&builder, true); + auto on_true = ConstantR0(&builder, 123.0f); + auto on_false = ConstantR0(&builder, 42.0f); + Select(pred, on_true, on_false); ComputeAndCompareR0(&builder, 123.0f, {}, error_spec_); } TEST_F(SelectTest, SelectScalarS32True) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto on_true = builder.ConstantR0(-42); - auto on_false = builder.ConstantR0(42); - auto result = builder.Select(pred, on_true, on_false); + auto pred = ConstantR0(&builder, true); + auto on_true = ConstantR0(&builder, -42); + auto on_false = ConstantR0(&builder, 42); + Select(pred, on_true, on_false); ComputeAndCompareR0(&builder, -42, {}); } TEST_F(SelectTest, SelectScalarF32False) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto on_true = builder.ConstantR0(123.0f); - auto on_false = builder.ConstantR0(42.0f); - auto result = builder.Select(pred, on_true, on_false); + auto pred = ConstantR0(&builder, false); + auto on_true = ConstantR0(&builder, 123.0f); + auto on_false = ConstantR0(&builder, 42.0f); + Select(pred, on_true, on_false); ComputeAndCompareR0(&builder, 42.0f, {}, error_spec_); } XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR1({}); - auto on_true = builder.ConstantR1({}); - auto on_false = builder.ConstantR1({}); - auto select = builder.Select(pred, on_true, on_false); + auto pred = ConstantR1(&builder, {}); + auto on_true = ConstantR1(&builder, {}); + auto on_false = ConstantR1(&builder, {}); + Select(pred, on_true, on_false); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR1({false, true, false, true, false}); - auto on_true = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); - auto on_false = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); - auto select = builder.Select(pred, on_true, on_false); + auto pred = ConstantR1(&builder, {false, true, false, true, false}); + auto on_true = + ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); + auto on_false = + ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); + Select(pred, on_true, on_false); ComputeAndCompareR1(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {}, error_spec_); @@ -88,12 +90,12 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) { // Similar to SelectR1S0F32WithConstantR1S0PRED, except that the pred vector // is not a constant, but rather the result of comparing two other vectors. XlaBuilder builder(TestName()); - auto v1 = builder.ConstantR1({}); - auto v2 = builder.ConstantR1({}); - auto cmp = builder.Eq(v1, v2); - auto on_true = builder.ConstantR1({}); - auto on_false = builder.ConstantR1({}); - auto select = builder.Select(cmp, on_true, on_false); + auto v1 = ConstantR1(&builder, {}); + auto v2 = ConstantR1(&builder, {}); + auto cmp = Eq(v1, v2); + auto on_true = ConstantR1(&builder, {}); + auto on_false = ConstantR1(&builder, {}); + Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -102,12 +104,14 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) { // Similar to SelectR1F32WithConstantR1PRED, except that the pred vector is // not a constant, but rather the result of comparing two other vectors. XlaBuilder builder(TestName()); - auto v1 = builder.ConstantR1({1, 2, 3, 4, 5}); - auto v2 = builder.ConstantR1({9, 2, 9, 4, 9}); - auto cmp = builder.Eq(v1, v2); - auto on_true = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); - auto on_false = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); - auto select = builder.Select(cmp, on_true, on_false); + auto v1 = ConstantR1(&builder, {1, 2, 3, 4, 5}); + auto v2 = ConstantR1(&builder, {9, 2, 9, 4, 9}); + auto cmp = Eq(v1, v2); + auto on_true = + ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); + auto on_false = + ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); + Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {}, error_spec_); @@ -116,12 +120,14 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) { TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) { // Similar to SelectR1F32WithCmpR1S32s, except "gt"-comparing two R1F32s. XlaBuilder builder(TestName()); - auto v1 = builder.ConstantR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); - auto v2 = builder.ConstantR1({-1.0f, -2.0f, 13.0f, 14.0f, 4.4f}); - auto cmp = builder.Gt(v1, v2); - auto on_true = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); - auto on_false = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); - auto select = builder.Select(cmp, on_true, on_false); + auto v1 = ConstantR1(&builder, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + auto v2 = ConstantR1(&builder, {-1.0f, -2.0f, 13.0f, 14.0f, 4.4f}); + auto cmp = Gt(v1, v2); + auto on_true = + ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); + auto on_false = + ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); + Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {-2.5f, 25.5f, 1.0f, 10.0f, 6.0f}, {}, error_spec_); @@ -140,8 +146,8 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) { {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&builder, /*data_handle=*/&v2); - auto cmp = builder.Gt(v1, v2); - auto select = builder.Select(cmp, v1, v2); + auto cmp = Gt(v1, v2); + Select(cmp, v1, v2); ComputeAndCompareR1(&builder, {41.0f, 22.0f, 23.0f, 84.0f}, {param0_data.get(), param1_data.get()}, error_spec_); @@ -181,8 +187,8 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) { CreateR1Parameter(v2vec, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&builder, /*data_handle=*/&v2); - auto cmp = builder.Gt(v1, v2); - auto select = builder.Select(cmp, v1, v2); + auto cmp = Gt(v1, v2); + Select(cmp, v1, v2); ComputeAndCompareR1(&builder, expected_vec, {param0_data.get(), param1_data.get()}, error_spec_); @@ -192,14 +198,14 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) { // "gt"-compares a R1S32 with a S32 scalar, and uses the resulting R1PRED to // select between two R1F32s. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({1, -1, 2, -2}); - auto s = builder.ConstantR0(0); - auto cmp = builder.Gt(v, s); + auto v = ConstantR1(&builder, {1, -1, 2, -2}); + auto s = ConstantR0(&builder, 0); + auto cmp = Gt(v, s); - auto on_true = builder.ConstantR1({11.0f, 22.0f, 33.0f, 44.0f}); + auto on_true = ConstantR1(&builder, {11.0f, 22.0f, 33.0f, 44.0f}); auto on_false = - builder.ConstantR1({-111.0f, -222.0f, -333.0f, -444.0f}); - auto select = builder.Select(cmp, on_true, on_false); + ConstantR1(&builder, {-111.0f, -222.0f, -333.0f, -444.0f}); + Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {11.0f, -222.0f, 33.0f, -444.0f}, {}, error_spec_); @@ -209,14 +215,14 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) { // "gt"-compares a R1F32 with a F32 scalar, and uses the resulting R1PRED to // select between two R1F32s. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({1.0f, 2.0f, 3.0f, 4.0f}); - auto s = builder.ConstantR0(2.5f); - auto cmp = builder.Gt(v, s); + auto v = ConstantR1(&builder, {1.0f, 2.0f, 3.0f, 4.0f}); + auto s = ConstantR0(&builder, 2.5f); + auto cmp = Gt(v, s); - auto on_true = builder.ConstantR1({11.0f, 22.0f, 33.0f, 44.0f}); + auto on_true = ConstantR1(&builder, {11.0f, 22.0f, 33.0f, 44.0f}); auto on_false = - builder.ConstantR1({-111.0f, -222.0f, -333.0f, -444.0f}); - auto select = builder.Select(cmp, on_true, on_false); + ConstantR1(&builder, {-111.0f, -222.0f, -333.0f, -444.0f}); + Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {-111.0f, -222.0f, 33.0f, 44.0f}, {}, error_spec_); @@ -225,10 +231,10 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) { XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) { for (bool which : {false, true}) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(which); - auto on_true = builder.ConstantR1({}); - auto on_false = builder.ConstantR1({}); - auto select = builder.Select(pred, on_true, on_false); + auto pred = ConstantR0(&builder, which); + auto on_true = ConstantR1(&builder, {}); + auto on_false = ConstantR1(&builder, {}); + Select(pred, on_true, on_false); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -236,20 +242,20 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) { TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto on_true = builder.ConstantR1({-2.5f, 25.5f}); - auto on_false = builder.ConstantR1({10.0f, 5.0f}); - auto select = builder.Select(pred, on_true, on_false); + auto pred = ConstantR0(&builder, true); + auto on_true = ConstantR1(&builder, {-2.5f, 25.5f}); + auto on_false = ConstantR1(&builder, {10.0f, 5.0f}); + Select(pred, on_true, on_false); ComputeAndCompareR1(&builder, {-2.5f, 25.5f}, {}, error_spec_); } TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto on_true = builder.ConstantR1({-2.5f, 25.5f}); - auto on_false = builder.ConstantR1({10.0f, 5.0f}); - auto select = builder.Select(pred, on_true, on_false); + auto pred = ConstantR0(&builder, false); + auto on_true = ConstantR1(&builder, {-2.5f, 25.5f}); + auto on_false = ConstantR1(&builder, {10.0f, 5.0f}); + Select(pred, on_true, on_false); ComputeAndCompareR1(&builder, {10.0f, 5.0f}, {}, error_spec_); } diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index 5653bf11a7364bf9ed79bcb6b53f7db31f454803..3e5c01d6d47cc3f3b7d46ce300fe26c5ec9e63fa 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -42,8 +42,8 @@ TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) { values.FillIota(0); XlaBuilder builder(TestName()); - auto original = builder.ConstantR3FromArray3D(values); - builder.Slice(original, {0, 0, 0}, {3, 3, 1}, {1, 1, 1}); + auto original = ConstantR3FromArray3D(&builder, values); + Slice(original, {0, 0, 0}, {3, 3, 1}, {1, 1, 1}); Array3D expected{ {{0.0}, {3.0}, {6.0}}, {{9.0}, {12.0}, {15.0}}, {{18.0}, {21.0}, {24.0}}}; @@ -55,8 +55,8 @@ TEST_F(SliceTest, Slice3x3x3_To_3x1x3_F32) { values.FillIota(0); XlaBuilder builder(TestName()); - auto original = builder.ConstantR3FromArray3D(values); - builder.Slice(original, {0, 0, 0}, {3, 1, 3}, {1, 1, 1}); + auto original = ConstantR3FromArray3D(&builder, values); + Slice(original, {0, 0, 0}, {3, 1, 3}, {1, 1, 1}); Array3D expected{ {{0.0, 1.0, 2.0}}, {{9.0, 10.0, 11.0}}, {{18.0, 19.0, 20.0}}}; @@ -68,8 +68,8 @@ TEST_F(SliceTest, Slice3x3x3_To_1x3x3_F32) { values.FillIota(0); XlaBuilder builder(TestName()); - auto original = builder.ConstantR3FromArray3D(values); - builder.Slice(original, {0, 0, 0}, {1, 3, 3}, {1, 1, 1}); + auto original = ConstantR3FromArray3D(&builder, values); + Slice(original, {0, 0, 0}, {1, 3, 3}, {1, 1, 1}); Array3D expected{ {{{0.0, 1.0, 2.0}, {3.0, 4.0, 5.0}, {6.0, 7.0, 8.0}}}}; @@ -78,24 +78,24 @@ TEST_F(SliceTest, Slice3x3x3_To_1x3x3_F32) { XLA_TEST_F(SliceTest, Slice0x0to0x0F32) { XlaBuilder builder(TestName()); - auto original = builder.ConstantR2FromArray2D(Array2D(0, 0)); - builder.Slice(original, {0, 0}, {0, 0}, {1, 1}); + auto original = ConstantR2FromArray2D(&builder, Array2D(0, 0)); + Slice(original, {0, 0}, {0, 0}, {1, 1}); ComputeAndCompareR2(&builder, Array2D(0, 0), {}); } XLA_TEST_F(SliceTest, Slice0x20to0x5F32) { XlaBuilder builder(TestName()); - auto original = builder.ConstantR2FromArray2D(Array2D(0, 20)); - builder.Slice(original, {0, 15}, {0, 20}, {1, 1}); + auto original = ConstantR2FromArray2D(&builder, Array2D(0, 20)); + Slice(original, {0, 15}, {0, 20}, {1, 1}); ComputeAndCompareR2(&builder, Array2D(0, 5), {}); } XLA_TEST_F(SliceTest, Slice3x0to2x0F32) { XlaBuilder builder(TestName()); - auto original = builder.ConstantR2FromArray2D(Array2D(3, 0)); - builder.Slice(original, {1, 0}, {3, 0}, {1, 1}); + auto original = ConstantR2FromArray2D(&builder, Array2D(3, 0)); + Slice(original, {1, 0}, {3, 0}, {1, 1}); ComputeAndCompareR2(&builder, Array2D(2, 0), {}); } @@ -109,8 +109,8 @@ XLA_TEST_F(SliceTest, SliceQuadrantOf256x256) { } XlaBuilder builder(TestName()); - auto original = builder.ConstantR2FromArray2D(values); - builder.Slice(original, {128, 128}, {256, 256}, {1, 1}); + auto original = ConstantR2FromArray2D(&builder, values); + Slice(original, {128, 128}, {256, 256}, {1, 1}); Array2D expected(128, 128); for (int row = 0; row < 128; ++row) { @@ -127,8 +127,8 @@ TEST_F(SliceTest, Slice_1x4096_To_1x1024) { std::iota(values.data(), values.data() + 4096, 0.0); XlaBuilder builder(TestName()); - auto original = builder.ConstantR2FromArray2D(values); - builder.Slice(original, {0, 3072}, {1, 4096}, {1, 1}); + auto original = ConstantR2FromArray2D(&builder, values); + Slice(original, {0, 3072}, {1, 4096}, {1, 1}); Array2D expected(1, 1024); std::iota(expected.data(), expected.data() + 1024, 3072.0); @@ -148,8 +148,8 @@ TEST_F(SliceTest, Slice_16x4_To_16x2) { } } XlaBuilder builder(TestName()); - auto original = builder.ConstantR2FromArray2D(values); - builder.Slice(original, {0, 0}, {16, 2}, {1, 1}); + auto original = ConstantR2FromArray2D(&builder, values); + Slice(original, {0, 0}, {16, 2}, {1, 1}); ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.000001)); } @@ -160,8 +160,8 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) { auto expected = ReferenceUtil::Slice4D( values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}}, /*strides=*/{{1, 1, 1, 1}}); XlaBuilder builder(TestName()); - auto original = builder.ConstantR4FromArray4D(values); - builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1}); + auto original = ConstantR4FromArray4D(&builder, values); + Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1}); ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001)); } @@ -173,8 +173,8 @@ XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) { auto expected_literal = Literal::CreateR4FromArray4DWithLayout( *expected, LayoutUtil::MakeLayout({0, 1, 2, 3})); XlaBuilder builder(TestName()); - auto original = builder.ConstantR4FromArray4D(values); - builder.Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1}); + auto original = ConstantR4FromArray4D(&builder, values); + Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1}); ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001), &expected_literal->shape()); } @@ -200,9 +200,9 @@ class SliceR1Test : public ClientLibraryTestBase, auto literal = Literal::CreateR1(input); XlaBuilder builder(TestName()); - auto original = builder.Parameter(0, literal->shape(), "p0"); - builder.Slice(original, {spec.slice_start}, {spec.slice_limit}, - {spec.slice_stride}); + auto original = Parameter(&builder, 0, literal->shape(), "p0"); + Slice(original, {spec.slice_start}, {spec.slice_limit}, + {spec.slice_stride}); // Ditto. tensorflow::gtl::InlinedVector expected; @@ -372,8 +372,8 @@ XLA_TEST_P(SliceR2Test, DoIt) { input, LayoutUtil::MakeLayout(spec.layout)); XlaBuilder builder(TestName()); - auto a = builder.Parameter(0, literal->shape(), "p0"); - builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); + auto a = Parameter(&builder, 0, literal->shape(), "p0"); + Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, client_->TransferToServer(*literal)); @@ -465,11 +465,10 @@ class SliceR4Test : public ClientLibraryTestBase, XlaBuilder builder(TestName()); auto literal = Literal::CreateR4FromArray4DWithLayout( values, LayoutUtil::MakeLayout(spec.input_layout)); - auto parameter = builder.Parameter(0, literal->shape(), "p0"); + auto parameter = Parameter(&builder, 0, literal->shape(), "p0"); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, client_->TransferToServer(*literal)); - builder.Slice(parameter, spec.slice_starts, spec.slice_limits, - spec.slice_strides); + Slice(parameter, spec.slice_starts, spec.slice_limits, spec.slice_strides); ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001)); } }; diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index dd7c541733634213606b5a7983b59bb1f14bf75c..20c7c30878a2821915d47bcf9fa1cc53907df9da 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -161,6 +161,9 @@ StatusOr> MakeFakeLiteralInternal( })); break; } + // Token requires no data. + case TOKEN: + break; default: return Unimplemented("Unsupported type for fake literal generation: %s", ShapeUtil::HumanString(shape).c_str()); @@ -270,14 +273,22 @@ StatusOr> CreateLiteralForConstrainedUses( switch (use->opcode()) { case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: - if (needs_index != nullptr && - !ShapeUtil::Equal(needs_index->shape(), use->shape())) { - return Unimplemented( - "Conflicting operand generation slice index constraints\n"); + if (needs_index != nullptr) { + auto needs_index_shape = needs_index->shape(); + auto use_shape = use->shape(); + if (needs_index->opcode() == HloOpcode::kDynamicSlice) { + needs_index_shape = needs_index->operand(0)->shape(); + } + if (use->opcode() == HloOpcode::kDynamicSlice) { + use_shape = use->operand(0)->shape(); + } + if (!ShapeUtil::Equal(needs_index_shape, use_shape)) { + return Unimplemented( + "Conflicting operand generation slice index constraints\n"); + } } needs_index = use; break; - case HloOpcode::kReduce: case HloOpcode::kReduceWindow: needs_constant = use; diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index 59afd28a80c0fbf3df38457cd05961c883769856..8f424ae81f592bfd8accd8decb8fc363f7561c73 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -31,16 +32,16 @@ XLA_TEST_F(TestUtilsTest, UnusedParam) { XlaBuilder builder(TestName()); // Make the reduction lambda. Shape single_float = ShapeUtil::MakeShape(F32, {}); - builder.Parameter(0, single_float, "unused"); - builder.Parameter(1, single_float, "used"); + Parameter(&builder, 0, single_float, "unused"); + Parameter(&builder, 1, single_float, "used"); auto computation_status = builder.Build(); TF_ASSERT_OK(computation_status.status()); // Make the reduction. Shape pair_float = ShapeUtil::MakeShape(F32, {2}); - builder.Reduce(builder.Parameter(0, pair_float, "operand"), - builder.Parameter(1, single_float, "init"), - computation_status.ValueOrDie(), {0}); + Reduce(Parameter(&builder, 0, pair_float, "operand"), + Parameter(&builder, 1, single_float, "init"), + computation_status.ValueOrDie(), {0}); computation_status = builder.Build(); TF_ASSERT_OK(computation_status.status()); @@ -53,5 +54,23 @@ XLA_TEST_F(TestUtilsTest, UnusedParam) { TF_ASSERT_OK(MakeFakeArguments(&module).status()); } +XLA_TEST_F(TestUtilsTest, Token) { + auto module = ParseHloString( + R"(HloModule outfeed_module + + ENTRY InfeedToOutfeed { + token = token[] parameter(0) + infeed = ((u32[3]{0}, pred[]), token[]) infeed(token) + infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0 + outfeed = token[] outfeed(infeed.data, token) + ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token) + infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0 + infeed.1.token = token[] get-tuple-element(infeed.1), index=1 + outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token) + })") + .ValueOrDie(); + TF_ASSERT_OK(MakeFakeArguments(module.get()).status()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index 3ef54e6f89251bbd6dba0705698c6627c554791e..e9008fa48aa7d0158bd2221791be23c128859098 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -31,27 +31,29 @@ class TokenHloTest : public HloTestBase {}; XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { std::unique_ptr module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction(HloInstruction::CreateGenerateToken({})); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); + builder.AddInstruction(HloInstruction::CreateAfterAll({})); module->AddEntryComputation(builder.Build()); - EXPECT_IS_OK(HloVerifier().Run(module.get()).status()); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + Execute(std::move(module), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *Literal::CreateToken())); } XLA_TEST_F(TokenHloTest, TokenTree) { std::unique_ptr module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - auto token0 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); - auto token1 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); - auto token2 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); - builder.AddInstruction( - HloInstruction::CreateGenerateToken({token0, token0, token1, token2})); + auto token0 = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto token1 = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto token2 = builder.AddInstruction(HloInstruction::CreateAfterAll({})); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); + HloInstruction::CreateAfterAll({token0, token0, token1, token2})); module->AddEntryComputation(builder.Build()); - EXPECT_IS_OK(HloVerifier().Run(module.get()).status()); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + Execute(std::move(module), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *Literal::CreateToken())); } XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { @@ -89,24 +91,12 @@ XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { ::testing::HasSubstr("Entry parameter 0 is or contains a token shape")); } -XLA_TEST_F(TokenHloTest, InvalidTokenRoot) { - std::unique_ptr module = CreateNewModule(); - auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction(HloInstruction::CreateGenerateToken({})); - module->AddEntryComputation(builder.Build()); - - Status status = HloVerifier().Run(module.get()).status(); - ASSERT_IS_NOT_OK(status); - EXPECT_THAT(status.error_message(), - ::testing::HasSubstr("Entry root is or contains a token shape")); -} - XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { std::unique_ptr module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); - builder.AddInstruction(HloInstruction::CreateGenerateToken({param})); + builder.AddInstruction(HloInstruction::CreateAfterAll({param})); builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(123))); module->AddEntryComputation(builder.Build()); @@ -120,7 +110,7 @@ XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { XLA_TEST_F(TokenHloTest, TokenInWhileLoop) { // Thread a token around a while loop. Token is created and consumed by a - // GenerateToken instruction in the while body. + // AfterAll instruction in the while body. string module_string = R"( HloModule TokenInWhileLoop @@ -130,8 +120,8 @@ HloModule TokenInWhileLoop %constant.1 = s32[] constant(1) %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 - %generate-token = token[] generate-token(token[] %get-tuple-element.2) - ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %generate-token) + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) } %Cond (param: (s32[], token[])) -> pred[] { @@ -143,14 +133,73 @@ HloModule TokenInWhileLoop ENTRY %TokenInWhileLoop () -> s32[] { %zero = s32[] constant(0) - %init_token = token[] generate-token() + %init_token = token[] after-all() %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 } )"; - EXPECT_TRUE(RunAndCompare(module_string, error_spec_)); + DebugOptions debug_options = GetDebugOptionsForTest(); + // Module DCE pass removes the generate token instructions. + debug_options.add_xla_disable_hlo_passes("hlo-module-dce"); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + HloRunner::CreateModuleFromString(module_string, debug_options)); + + EXPECT_TRUE(RunAndCompare(std::move(module), error_spec_)); +} + +XLA_TEST_F(TokenHloTest, TokenInConditional) { + string module_string = R"( +HloModule TokenInConditional + +%True (param.1: token[]) -> (s32[], token[]) { + %param.1 = token[] parameter(0) + %forty_two = s32[] constant(42) + ROOT %tuple = (s32[], token[]) tuple(s32[] %forty_two, token[] %param.1) +} + +%False (param.2: s32[]) -> (s32[], token[]) { + %param.2 = s32[] parameter(0) + %new_token = token[] after-all() + ROOT %tuple = (s32[], token[]) tuple(s32[] %param.2, token[] %new_token) +} + +ENTRY %TokenInConditional (param.3: pred[]) -> s32[] { + %param.3 = pred[] parameter(0) + %init_token = token[] after-all() + %seven = s32[] constant(7) + %cond = (s32[], token[]) conditional(pred[] %param.3, token[] %init_token, s32[] %seven), true_computation=True, false_computation=False + ROOT %root = s32[] get-tuple-element((s32[], token[]) %cond), index=0 +} +)"; + + DebugOptions debug_options = GetDebugOptionsForTest(); + // Module DCE pass removes the generate token instructions. + debug_options.add_xla_disable_hlo_passes("hlo-module-dce"); + + { + // True case. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + HloRunner::CreateModuleFromString(module_string, debug_options)); + auto arg = Literal::CreateR0(true); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + Execute(std::move(module), {arg.get()})); + EXPECT_EQ(42, result->Get({})); + } + + { + // False case. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + HloRunner::CreateModuleFromString(module_string, debug_options)); + auto arg = Literal::CreateR0(false); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + Execute(std::move(module), {arg.get()})); + EXPECT_EQ(7, result->Get({})); + } } } // namespace diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index 0063e7ad415e9b6718c164f415ced6fb76cbf44a..86babb58c9d4515935a5904e04e8fea1074a2812 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -41,7 +42,12 @@ class TransferManagerTest : public LocalClientTestBase { TransferManagerTest() : shape_size_fn_([this](const Shape& shape) { return transfer_manager_->GetByteSizeRequirement(shape); - }) {} + }) { + stream_ptr_ = local_client_->mutable_backend() + ->BorrowStream(stream_executor_) + .ValueOrDie(); + stream_ = stream_ptr_.get(); + } ~TransferManagerTest() override = default; @@ -53,6 +59,10 @@ class TransferManagerTest : public LocalClientTestBase { .ValueOrDie(); } + protected: + Backend::StreamPtr stream_ptr_; + se::Stream* stream_; + private: std::function shape_size_fn_; }; @@ -63,11 +73,11 @@ XLA_TEST_F(TransferManagerTest, TransferR0U32) { auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR0Equal(42, *result); } @@ -79,11 +89,11 @@ XLA_TEST_F(TransferManagerTest, TransferR1F32) { auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR1Equal({1.25f, 2.5f, -17.0f, -20.125f}, *result); @@ -97,11 +107,11 @@ XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) { auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR1Equal(test_vector, *result); } @@ -113,11 +123,11 @@ XLA_TEST_F(TransferManagerTest, TransferR1U8) { auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_EQ(result->GetR1U8AsString(), test_string); } @@ -129,11 +139,11 @@ XLA_TEST_F(TransferManagerTest, TransferR2F32) { auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR2Equal( {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result); @@ -149,11 +159,11 @@ XLA_TEST_F(TransferManagerTest, // Round trip literal through device. Set the on-device layout to something // different than the literal layout. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_FALSE( LayoutUtil::Equal(result->shape().layout(), literal->shape().layout())); @@ -169,11 +179,11 @@ XLA_TEST_F(TransferManagerTest, TransferTuple) { auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } @@ -183,11 +193,11 @@ XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } @@ -203,11 +213,11 @@ XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } @@ -218,11 +228,11 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValue) { auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } @@ -237,14 +247,162 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } +XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) { + // "Copy" a token from the device. The token has no physical representation so + // no copying is actually performed, but it shouldn't fail. + // TODO(b/110532604): Add transferring the token to device when this is + // supported. + auto device_buffer = AllocateDeviceBuffer(ShapeUtil::MakeTokenShape()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); + EXPECT_TRUE(LiteralTestUtil::Equal(*Literal::CreateToken(), *result)); +} + +XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) { + const int64 kIterationCount = 5000; + std::unique_ptr literal1 = Literal::MakeTuple( + {Literal::CreateR0(123.0f).get(), + Literal::MakeTuple( + {Literal::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), + Literal::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) + .get(), + Literal::CreateR1({-10.0f, 123.0f}).get()}); + std::unique_ptr literal2 = Literal::MakeTuple( + {Literal::CreateR0(456.0f).get(), + Literal::MakeTuple( + {Literal::CreateR2({{5.0f, 7.0f}, {9.0f, 4.0f}}).get(), + Literal::CreateR1({44.0f, -11.0f, 3333333.3f}).get()}) + .get(), + Literal::CreateR1({-98.0f, 153.0f}).get()}); + + auto device_buffer1 = AllocateDeviceBuffer(literal1->shape()); + auto device_buffer2 = AllocateDeviceBuffer(literal2->shape()); + + auto stream1 = stream_; + auto stream2 = stream_->GetOrCreateSubStream(); + + std::unique_ptr result1, result2; + + // Round trip literals through device in multiple streams asynchronously. + for (int i = 0; i < kIterationCount; ++i) { + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, *literal1, + device_buffer1)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, *literal2, + device_buffer2)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr this_result1, + transfer_manager_->TransferLiteralFromDevice(stream1, device_buffer1)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr this_result2, + transfer_manager_->TransferLiteralFromDevice(stream2, device_buffer2)); + result1 = std::move(this_result1); + result2 = std::move(this_result2); + } + + EXPECT_TRUE(LiteralTestUtil::Equal(*literal1, *result1)); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal2, *result2)); +} + +class TransferDeviceToHostBenchmark : public TransferManagerTest { + public: + using TransferManagerTest::TransferManagerTest; + ~TransferDeviceToHostBenchmark() override {} + + void Run(int iters, int num_tuple_elements, int array_size) { + tensorflow::testing::StopTiming(); + SetUp(); + + std::vector> tuple_elements; + for (int i = 0; i < num_tuple_elements; ++i) { + tuple_elements.push_back( + Literal::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); + } + std::unique_ptr literal = + Literal::MakeTupleOwned(std::move(tuple_elements)); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); + } + tensorflow::testing::StopTiming(); + TearDown(); + } + + void TestBody() override {} +}; + +class TransferHostToDeviceBenchmark : public TransferManagerTest { + public: + using TransferManagerTest::TransferManagerTest; + ~TransferHostToDeviceBenchmark() override {} + + void Run(int iters, int num_tuple_elements, int array_size) { + tensorflow::testing::StopTiming(); + SetUp(); + + std::vector> tuple_elements; + for (int i = 0; i < num_tuple_elements; ++i) { + tuple_elements.push_back( + Literal::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); + } + std::unique_ptr literal = + Literal::MakeTupleOwned(std::move(tuple_elements)); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + } + tensorflow::testing::StopTiming(); + TearDown(); + } + + void TestBody() override {} +}; + +void BM_TransferDeviceToHost(int iters, int num_tuple_elements, + int array_size) { + TransferDeviceToHostBenchmark bm; + bm.Run(iters, num_tuple_elements, array_size); +} + +void BM_TransferHostToDevice(int iters, int num_tuple_elements, + int array_size) { + TransferHostToDeviceBenchmark bm; + bm.Run(iters, num_tuple_elements, array_size); +} + +BENCHMARK(BM_TransferHostToDevice) + ->ArgPair(1, 256) + ->ArgPair(1, 257) + ->ArgPair(100, 256) + ->ArgPair(100, 257); + +BENCHMARK(BM_TransferDeviceToHost) + ->ArgPair(1, 256) + ->ArgPair(1, 257) + ->ArgPair(100, 256) + ->ArgPair(100, 257); + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + tensorflow::testing::RunBenchmarks(); + return RUN_ALL_TESTS(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/transpose_test.cc b/tensorflow/compiler/xla/tests/transpose_test.cc index fe1e3da7eca00e128377e6e56af877868aafa836..6ebb4324f8d20ed9f8886d92b0513441685ed19b 100644 --- a/tensorflow/compiler/xla/tests/transpose_test.cc +++ b/tensorflow/compiler/xla/tests/transpose_test.cc @@ -38,34 +38,35 @@ class TransposeTest : public ClientLibraryTestBase { XLA_TEST_F(TransposeTest, Transpose0x0) { XlaBuilder builder("Transpose"); - auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 0)); - auto result = builder.Transpose(lhs, {1, 0}); + auto lhs = ConstantR2FromArray2D(&builder, Array2D(0, 0)); + Transpose(lhs, {1, 0}); ComputeAndCompareR2(&builder, Array2D(0, 0), {}, error_spec_); } XLA_TEST_F(TransposeTest, Transpose0x42) { XlaBuilder builder("Transpose"); - auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 42)); - auto result = builder.Transpose(lhs, {1, 0}); + auto lhs = ConstantR2FromArray2D(&builder, Array2D(0, 42)); + Transpose(lhs, {1, 0}); ComputeAndCompareR2(&builder, Array2D(42, 0), {}, error_spec_); } XLA_TEST_F(TransposeTest, Transpose7x0) { XlaBuilder builder("Transpose"); - auto lhs = builder.ConstantR2FromArray2D(Array2D(7, 0)); - auto result = builder.Transpose(lhs, {1, 0}); + auto lhs = ConstantR2FromArray2D(&builder, Array2D(7, 0)); + Transpose(lhs, {1, 0}); ComputeAndCompareR2(&builder, Array2D(0, 7), {}, error_spec_); } TEST_F(TransposeTest, Transpose2x2) { XlaBuilder builder("Transpose"); - auto lhs = builder.ConstantR2({ - {1.0, 2.0}, {3.0, 4.0}, - }); - auto result = builder.Transpose(lhs, {1, 0}); + auto lhs = ConstantR2(&builder, { + {1.0, 2.0}, + {3.0, 4.0}, + }); + Transpose(lhs, {1, 0}); Array2D expected({{1.0f, 3.0f}, {2.0f, 4.0f}}); @@ -74,16 +75,18 @@ TEST_F(TransposeTest, Transpose2x2) { XLA_TEST_F(TransposeTest, Transpose0x2x3_2x3x0) { XlaBuilder builder("Transpose"); - auto operand = builder.ConstantR3FromArray3D(Array3D(0, 2, 3)); - auto result = builder.Transpose(operand, {1, 2, 0}); + auto operand = + ConstantR3FromArray3D(&builder, Array3D(0, 2, 3)); + Transpose(operand, {1, 2, 0}); ComputeAndCompareR3(&builder, Array3D(2, 3, 0), {}); } TEST_F(TransposeTest, Transpose1x2x3_2x3x1) { XlaBuilder builder("Transpose"); - auto operand = builder.ConstantR3FromArray3D({{{1, 2, 3}, {4, 5, 6}}}); - auto result = builder.Transpose(operand, {1, 2, 0}); + auto operand = + ConstantR3FromArray3D(&builder, {{{1, 2, 3}, {4, 5, 6}}}); + Transpose(operand, {1, 2, 0}); Array3D expected({{{1}, {2}, {3}}, {{4}, {5}, {6}}}); @@ -92,8 +95,9 @@ TEST_F(TransposeTest, Transpose1x2x3_2x3x1) { TEST_F(TransposeTest, Transpose1x2x3_3x2x1) { XlaBuilder builder("Transpose"); - auto operand = builder.ConstantR3FromArray3D({{{1, 2, 3}, {4, 5, 6}}}); - auto result = builder.Transpose(operand, {2, 1, 0}); + auto operand = + ConstantR3FromArray3D(&builder, {{{1, 2, 3}, {4, 5, 6}}}); + Transpose(operand, {2, 1, 0}); Array3D expected({{{1}, {4}}, {{2}, {5}}, {{3}, {6}}}); @@ -102,8 +106,9 @@ TEST_F(TransposeTest, Transpose1x2x3_3x2x1) { TEST_F(TransposeTest, Transpose1x2x3_1x2x3) { XlaBuilder builder("Transpose"); - auto operand = builder.ConstantR3FromArray3D({{{1, 2, 3}, {4, 5, 6}}}); - auto result = builder.Transpose(operand, {0, 1, 2}); + auto operand = + ConstantR3FromArray3D(&builder, {{{1, 2, 3}, {4, 5, 6}}}); + Transpose(operand, {0, 1, 2}); Array3D expected({{{1, 2, 3}, {4, 5, 6}}}); @@ -116,9 +121,9 @@ TEST_F(TransposeTest, MultiTranspose3x2) { for (int transposes = 0; transposes <= 10; ++transposes) { XlaBuilder builder("Transpose"); - auto computed = builder.ConstantR2FromArray2D(input); + auto computed = ConstantR2FromArray2D(&builder, input); for (int i = 0; i < transposes; ++i) { - computed = builder.Transpose(computed, {1, 0}); + computed = Transpose(computed, {1, 0}); } const Array2D& expected = transposes % 2 == 0 ? input : transposed; ComputeAndCompareR2(&builder, expected, {}, error_spec_); @@ -130,8 +135,8 @@ TEST_F(TransposeTest, Small_1x1) { auto aoperand = MakeLinspaceArray2D(0.0, 1.0, 1, 1); XlaBuilder builder("transpose_1x1"); - auto operand = builder.ConstantR2FromArray2D(*aoperand); - builder.Transpose(operand, {1, 0}); + auto operand = ConstantR2FromArray2D(&builder, *aoperand); + Transpose(operand, {1, 0}); auto expected = ReferenceUtil::TransposeArray2D(*aoperand); ComputeAndCompareR2(&builder, *expected, {}, ErrorSpec(1e-4)); @@ -142,8 +147,8 @@ TEST_F(TransposeTest, Small_2x2) { auto aoperand = MakeLinspaceArray2D(0.0, 4.0, 2, 2); XlaBuilder builder("transpose_2x2"); - auto operand = builder.ConstantR2FromArray2D(*aoperand); - builder.Transpose(operand, {1, 0}); + auto operand = ConstantR2FromArray2D(&builder, *aoperand); + Transpose(operand, {1, 0}); auto expected = ReferenceUtil::TransposeArray2D(*aoperand); ComputeAndCompareR2(&builder, *expected, {}, ErrorSpec(1e-4)); @@ -162,8 +167,8 @@ void TransposeTest::TestTransposeConstant021(size_t n1, size_t n2, size_t n3) { } XlaBuilder builder(TestName()); - auto operand = builder.ConstantR3FromArray3D(aoperand); - builder.Transpose(operand, {0, 2, 1}); + auto operand = ConstantR3FromArray3D(&builder, aoperand); + Transpose(operand, {0, 2, 1}); ComputeAndCompareR3(&builder, expected, {}); } diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 41189231b90e842292830a932cf381af60456d4c..ec11508891d13f8032a1ebec388c756cf6d752c7 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -54,7 +54,7 @@ XLA_TEST_F(TupleTest, TupleConstant) { Literal::CreateR1(constant_vector).get(), Literal::CreateR2(constant_matrix).get()}); - builder.ConstantLiteral(*value); + ConstantLiteral(&builder, *value); ComputeAndCompareTuple(&builder, *value, {}, error_spec_); } @@ -68,7 +68,7 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) { Literal::MakeTuple({Literal::CreateR0(constant_scalar1).get(), Literal::CreateR0(constant_scalar2).get()}); - builder.ConstantLiteral(*value); + ConstantLiteral(&builder, *value); ComputeAndCompareTuple(&builder, *value, {}, error_spec_); } @@ -82,9 +82,9 @@ XLA_TEST_F(TupleTest, TupleCreate) { {1.1f, 2.2f, 3.5f}, // row 0 {4.8f, 5.0f, 6.7f}, // row 1 }; - builder.Tuple({builder.ConstantR0(constant_scalar), - builder.ConstantR1(constant_vector), - builder.ConstantR2(constant_matrix)}); + Tuple(&builder, {ConstantR0(&builder, constant_scalar), + ConstantR1(&builder, constant_vector), + ConstantR2(&builder, constant_matrix)}); auto expected = Literal::MakeTuple({Literal::CreateR0(constant_scalar).get(), @@ -97,8 +97,8 @@ XLA_TEST_F(TupleTest, TupleCreate) { XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { XlaBuilder builder(TestName()); - builder.Tuple( - {builder.ConstantR0(7.0), builder.ConstantR1({})}); + Tuple(&builder, + {ConstantR0(&builder, 7.0), ConstantR1(&builder, {})}); auto expected = Literal::MakeTuple({Literal::CreateR0(7.0).get(), Literal::CreateR1({}).get()}); @@ -108,7 +108,7 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { // Tests the creation of an empty tuple. XLA_TEST_F(TupleTest, EmptyTupleCreate) { XlaBuilder builder(TestName()); - builder.Tuple({}); + Tuple(&builder, {}); auto expected = Literal::MakeTuple({}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -121,9 +121,10 @@ XLA_TEST_F(TupleTest, GetTupleElement) { {1.f, 2.f, 3.f}, // row 0 {4.f, 5.f, 6.f}, // row 1 }; - auto tuple_data = builder.Tuple({builder.ConstantR1(constant_vector), - builder.ConstantR2(constant_matrix)}); - builder.GetTupleElement(tuple_data, 1); + auto tuple_data = + Tuple(&builder, {ConstantR1(&builder, constant_vector), + ConstantR2(&builder, constant_matrix)}); + GetTupleElement(tuple_data, 1); ComputeAndCompareR2(&builder, Array2D(constant_matrix), {}, error_spec_); } @@ -131,17 +132,18 @@ XLA_TEST_F(TupleTest, GetTupleElement) { // Trivial test for extracting a tuple element with GetTupleElement. XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) { XlaBuilder builder(TestName()); - auto tuple_data = builder.Tuple( - {builder.ConstantR1({}), - builder.ConstantR2FromArray2D(Array2D(0, 101))}); - builder.GetTupleElement(tuple_data, 1); + auto tuple_data = + Tuple(&builder, + {ConstantR1(&builder, {}), + ConstantR2FromArray2D(&builder, Array2D(0, 101))}); + GetTupleElement(tuple_data, 1); ComputeAndCompareR2(&builder, Array2D(0, 101), {}, error_spec_); } XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) { XlaBuilder builder(TestName()); - auto value = builder.ConstantR1({4.5f}); - builder.GetTupleElement(value, 1); + auto value = ConstantR1(&builder, {4.5f}); + GetTupleElement(value, 1); auto result_status = builder.Build(); EXPECT_FALSE(result_status.ok()); EXPECT_THAT( @@ -158,14 +160,15 @@ XLA_TEST_F(TupleTest, AddTupleElements) { {1.f, 2.f, 3.f}, // row 0 {4.f, 5.f, 6.f}, // row 1 }; - auto tuple_data = builder.Tuple({builder.ConstantR1(constant_vector), - builder.ConstantR2(constant_matrix)}); - auto vector_element = builder.GetTupleElement(tuple_data, 0); - auto matrix_element = builder.GetTupleElement(tuple_data, 1); + auto tuple_data = + Tuple(&builder, {ConstantR1(&builder, constant_vector), + ConstantR2(&builder, constant_matrix)}); + auto vector_element = GetTupleElement(tuple_data, 0); + auto matrix_element = GetTupleElement(tuple_data, 1); auto vector_shape = builder.GetShape(vector_element).ConsumeValueOrDie(); auto matrix_shape = builder.GetShape(matrix_element).ConsumeValueOrDie(); - builder.Add(matrix_element, vector_element, - /*broadcast_dimensions=*/{1}); + Add(matrix_element, vector_element, + /*broadcast_dimensions=*/{1}); Array2D expected({ {2.f, 4.f, 6.f}, // row 0 @@ -185,10 +188,11 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { {1.f, 2.f, 3.f}, // row 0 {4.f, 5.f, 6.f}, // row 1 }; - auto tuple_data = builder.Tuple({builder.ConstantR1(constant_vector), - builder.ConstantR2(constant_matrix)}); - builder.Tuple({builder.GetTupleElement(tuple_data, 1), - builder.GetTupleElement(tuple_data, 0)}); + auto tuple_data = + Tuple(&builder, {ConstantR1(&builder, constant_vector), + ConstantR2(&builder, constant_matrix)}); + Tuple(&builder, + {GetTupleElement(tuple_data, 1), GetTupleElement(tuple_data, 0)}); auto expected = Literal::MakeTuple({Literal::CreateR2(constant_matrix).get(), Literal::CreateR1(constant_vector).get()}); @@ -206,11 +210,11 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { std::unique_ptr v2_data = CreateR0Parameter(1.0f, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&b, /*data_handle=*/&v2); - auto v1_gt = b.Gt(v1, v2); // false - auto v2_gt = b.Gt(v2, v1); // true - auto v1_v2 = b.Tuple({v1_gt, v2_gt}); // {false, true} - auto v2_v1 = b.Tuple({v2_gt, v1_gt}); // {true, false} - b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); + auto v1_gt = Gt(v1, v2); // false + auto v2_gt = Gt(v2, v1); // true + auto v1_v2 = Tuple(&b, {v1_gt, v2_gt}); // {false, true} + auto v2_v1 = Tuple(&b, {v2_gt, v1_gt}); // {true, false} + Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); auto expected = Literal::MakeTuple({Literal::CreateR0(direction).get(), Literal::CreateR0(!direction).get()}); @@ -243,22 +247,23 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) { {1.f, 2.f, 3.f}, // row 0 {4.f, 5.f, 6.f}, // row 1 }; - auto tuple_data = builder.Tuple({builder.ConstantR1(constant_vector), - builder.ConstantR2(constant_matrix)}); - auto new_tuple01 = builder.Tuple({builder.GetTupleElement(tuple_data, 0), - builder.GetTupleElement(tuple_data, 1)}); - auto new_tuple10 = builder.Tuple({builder.GetTupleElement(tuple_data, 1), - builder.GetTupleElement(tuple_data, 0)}); - auto vector_from_01 = builder.GetTupleElement(new_tuple01, 0); - auto vector_from_10 = builder.GetTupleElement(new_tuple10, 1); - auto matrix_from_01 = builder.GetTupleElement(new_tuple01, 1); - auto matrix_from_10 = builder.GetTupleElement(new_tuple10, 0); - - auto addvectors = builder.Add(vector_from_01, vector_from_10); - auto addmatrices = builder.Add(matrix_from_01, matrix_from_10); - - builder.Add(addmatrices, addvectors, - /*broadcast_dimensions=*/{1}); + auto tuple_data = + Tuple(&builder, {ConstantR1(&builder, constant_vector), + ConstantR2(&builder, constant_matrix)}); + auto new_tuple01 = Tuple(&builder, {GetTupleElement(tuple_data, 0), + GetTupleElement(tuple_data, 1)}); + auto new_tuple10 = Tuple(&builder, {GetTupleElement(tuple_data, 1), + GetTupleElement(tuple_data, 0)}); + auto vector_from_01 = GetTupleElement(new_tuple01, 0); + auto vector_from_10 = GetTupleElement(new_tuple10, 1); + auto matrix_from_01 = GetTupleElement(new_tuple01, 1); + auto matrix_from_10 = GetTupleElement(new_tuple10, 0); + + auto addvectors = Add(vector_from_01, vector_from_10); + auto addmatrices = Add(matrix_from_01, matrix_from_10); + + Add(addmatrices, addvectors, + /*broadcast_dimensions=*/{1}); Array2D expected({ {4.f, 8.f, 12.f}, // row 0 @@ -273,12 +278,12 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) { std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; - auto tuple12 = builder.Tuple( - {builder.ConstantR1(vec1), builder.ConstantR1(vec2)}); - auto tuple21 = builder.Tuple( - {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); + auto tuple12 = Tuple(&builder, {ConstantR1(&builder, vec1), + ConstantR1(&builder, vec2)}); + auto tuple21 = Tuple(&builder, {ConstantR1(&builder, vec2), + ConstantR1(&builder, vec1)}); - builder.Select(builder.ConstantR0(false), tuple12, tuple21); + Select(ConstantR0(&builder, false), tuple12, tuple21); auto expected = Literal::MakeTuple({Literal::CreateR1(vec2).get(), Literal::CreateR1(vec1).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); @@ -292,22 +297,22 @@ XLA_TEST_F(TupleTest, TuplesInAMap) { // Need to put a select in there to prevent HLO-level optimizations from // optimizing out the tuples. XlaBuilder b("sort_square"); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto x2 = b.Mul(x, x); - auto x_smaller_tuple = b.Tuple({x, x2}); - auto x2_smaller_tuple = b.Tuple({x2, x}); - auto sorted = b.Select(b.Lt(x, x2), x_smaller_tuple, x2_smaller_tuple); - auto smaller = b.GetTupleElement(sorted, 0); - auto greater = b.GetTupleElement(sorted, 1); - b.Add(greater, b.Mul(b.ConstantR0(100.0f), smaller)); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto x2 = Mul(x, x); + auto x_smaller_tuple = Tuple(&b, {x, x2}); + auto x2_smaller_tuple = Tuple(&b, {x2, x}); + auto sorted = Select(Lt(x, x2), x_smaller_tuple, x2_smaller_tuple); + auto smaller = GetTupleElement(sorted, 0); + auto greater = GetTupleElement(sorted, 1); + Add(greater, Mul(ConstantR0(&b, 100.0f), smaller)); auto computation_status = b.Build(); ASSERT_IS_OK(computation_status.status()); tuple_computation = computation_status.ConsumeValueOrDie(); } XlaBuilder b(TestName()); - auto input = b.ConstantR1({-1.0f, 1.0f, 2.1f}); - b.Map({input}, tuple_computation, {0}); + auto input = ConstantR1(&b, {-1.0f, 1.0f, 2.1f}); + Map(&b, {input}, tuple_computation, {0}); ComputeAndCompareR1(&b, {-99.0f, 101.0f, 214.41f}, {}, error_spec_); } @@ -317,12 +322,12 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) { std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; - auto tuple12 = builder.Tuple( - {builder.ConstantR1(vec1), builder.ConstantR1(vec2)}); - auto tuple21 = builder.Tuple( - {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); + auto tuple12 = Tuple(&builder, {ConstantR1(&builder, vec1), + ConstantR1(&builder, vec2)}); + auto tuple21 = Tuple(&builder, {ConstantR1(&builder, vec2), + ConstantR1(&builder, vec1)}); - builder.Select(builder.ConstantR0(true), tuple12, tuple21); + Select(ConstantR0(&builder, true), tuple12, tuple21); auto expected = Literal::MakeTuple({Literal::CreateR1(vec1).get(), Literal::CreateR1(vec2).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); @@ -335,14 +340,13 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) { std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; - auto tuple12 = builder.Tuple( - {builder.ConstantR1(vec1), builder.ConstantR1(vec2)}); - auto tuple21 = builder.Tuple( - {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); + auto tuple12 = Tuple(&builder, {ConstantR1(&builder, vec1), + ConstantR1(&builder, vec2)}); + auto tuple21 = Tuple(&builder, {ConstantR1(&builder, vec2), + ConstantR1(&builder, vec1)}); - auto select = - builder.Select(builder.ConstantR0(false), tuple12, tuple21); - builder.GetTupleElement(select, 0); + auto select = Select(ConstantR0(&builder, false), tuple12, tuple21); + GetTupleElement(select, 0); ComputeAndCompareR1(&builder, vec2, {}, error_spec_); } @@ -371,19 +375,16 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesCascaded) { std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; - auto pred_tuple = builder.Tuple( - {builder.ConstantR0(true), builder.ConstantR0(false)}); - auto tuple12 = builder.Tuple( - {builder.ConstantR1(vec1), builder.ConstantR1(vec2)}); - auto tuple21 = builder.Tuple( - {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); + auto pred_tuple = Tuple(&builder, {ConstantR0(&builder, true), + ConstantR0(&builder, false)}); + auto tuple12 = Tuple(&builder, {ConstantR1(&builder, vec1), + ConstantR1(&builder, vec2)}); + auto tuple21 = Tuple(&builder, {ConstantR1(&builder, vec2), + ConstantR1(&builder, vec1)}); - auto select1 = - builder.Select(builder.GetTupleElement(pred_tuple, 0), tuple12, tuple21); - auto select2 = - builder.Select(builder.GetTupleElement(pred_tuple, 1), tuple21, select1); - builder.Add(builder.GetTupleElement(select2, 0), - builder.GetTupleElement(select2, 1)); + auto select1 = Select(GetTupleElement(pred_tuple, 0), tuple12, tuple21); + auto select2 = Select(GetTupleElement(pred_tuple, 1), tuple21, select1); + Add(GetTupleElement(select2, 0), GetTupleElement(select2, 1)); ComputeAndCompareR1(&builder, {3.f, 6.f, 9.f}, {}, error_spec_); } @@ -395,12 +396,12 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) { std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; - auto c1 = builder.ConstantR1(vec1); - auto c2 = builder.ConstantR1(vec2); - auto tuple12 = builder.Tuple({c1, c2}); - auto tuple21 = builder.Tuple({c2, c1}); + auto c1 = ConstantR1(&builder, vec1); + auto c2 = ConstantR1(&builder, vec2); + auto tuple12 = Tuple(&builder, {c1, c2}); + auto tuple21 = Tuple(&builder, {c2, c1}); - builder.Select(builder.ConstantR0(false), tuple12, tuple21); + Select(ConstantR0(&builder, false), tuple12, tuple21); auto expected = Literal::MakeTuple({Literal::CreateR1(vec2).get(), Literal::CreateR1(vec1).get()}); @@ -409,9 +410,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) { XLA_TEST_F(TupleTest, NestedTuples) { XlaBuilder builder(TestName()); - auto inner_tuple = builder.Tuple( - {builder.ConstantR1({1.0, 2.0}), builder.ConstantR0(42.0)}); - builder.Tuple({inner_tuple, builder.ConstantR1({22.0, 44.0})}); + auto inner_tuple = Tuple(&builder, {ConstantR1(&builder, {1.0, 2.0}), + ConstantR0(&builder, 42.0)}); + Tuple(&builder, {inner_tuple, ConstantR1(&builder, {22.0, 44.0})}); auto expected_v1 = Literal::CreateR1({1.0, 2.0}); auto expected_s = Literal::CreateR0(42.0); @@ -432,10 +433,10 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { Shape outer_tuple_shape = ShapeUtil::MakeTupleShape({inner_tuple_shape, data_shape}); - auto input = builder.Parameter(0, outer_tuple_shape, "input"); - auto gte0 = builder.GetTupleElement(input, 0); - auto gte1 = builder.GetTupleElement(gte0, 1); - builder.Add(gte1, builder.ConstantR1({10.0, 11.0, 12.0})); + auto input = Parameter(&builder, 0, outer_tuple_shape, "input"); + auto gte0 = GetTupleElement(input, 0); + auto gte1 = GetTupleElement(gte0, 1); + Add(gte1, ConstantR1(&builder, {10.0, 11.0, 12.0})); std::unique_ptr data = client_ @@ -463,16 +464,16 @@ XLA_TEST_F(TupleTest, ComplexTuples) { Shape c64r2 = ShapeUtil::MakeShape(C64, {3, 2}); Shape arg0_shape = ShapeUtil::MakeTupleShape( {c64r0, ShapeUtil::MakeTupleShape({c64r1, c64r2})}); - auto input0 = builder.Parameter(0, arg0_shape, "input0"); - auto t0 = builder.GetTupleElement(input0, 0); - auto t1 = builder.GetTupleElement(input0, 1); - auto t10 = builder.GetTupleElement(t1, 0); - auto t11 = builder.GetTupleElement(t1, 1); - auto sum = builder.Add(builder.Add(t10, t11, {1}), t0); - auto input1 = builder.Parameter(1, c64r1, "input1"); - auto prod = builder.Mul(input1, sum, {1}); - builder.Tuple({builder.Tuple({prod, sum}), - builder.ConstantR0({123, 456})}); + auto input0 = Parameter(&builder, 0, arg0_shape, "input0"); + auto t0 = GetTupleElement(input0, 0); + auto t1 = GetTupleElement(input0, 1); + auto t10 = GetTupleElement(t1, 0); + auto t11 = GetTupleElement(t1, 1); + auto sum = Add(Add(t10, t11, {1}), t0); + auto input1 = Parameter(&builder, 1, c64r1, "input1"); + auto prod = Mul(input1, sum, {1}); + Tuple(&builder, {Tuple(&builder, {prod, sum}), + ConstantR0(&builder, {123, 456})}); } std::unique_ptr arg0 = @@ -532,8 +533,8 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { auto param = Literal::MakeTupleOwned(Literal::CreateR1({1, 2, 3})); auto result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, - *Literal::MakeTupleOwned(Literal::CreateR2({{1, 2, 3}})))); + *Literal::MakeTupleOwned(Literal::CreateR2({{1, 2, 3}})), + *result)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index c3abe22797f5eaa76ced2ad8534bd68c32983e60..929b1ca7fb93c545265bf85fec1ed7dc845405b2 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -38,8 +38,8 @@ class UnaryOpTest : public ClientLibraryTestBase { template void AbsSize0TestHelper() { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1({}); - auto abs = builder.Abs(arg); + auto arg = ConstantR1(&builder, {}); + Abs(arg); if (primitive_util::NativeToPrimitiveType() == C64) { ComputeAndCompareR1(&builder, {}, {}); @@ -51,8 +51,8 @@ class UnaryOpTest : public ClientLibraryTestBase { template void AbsTestHelper() { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1({-2, 25, 0, -123, inf(), -inf()}); - auto abs = builder.Abs(arg); + auto arg = ConstantR1(&builder, {-2, 25, 0, -123, inf(), -inf()}); + Abs(arg); ComputeAndCompareR1(&builder, {2, 25, 0, 123, inf(), inf()}, {}); } @@ -60,9 +60,9 @@ class UnaryOpTest : public ClientLibraryTestBase { template void SignTestHelper() { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1( - {-2, 25, 0, static_cast(-0.0), -123, inf(), -inf()}); - auto sign = builder.Sign(arg); + auto arg = ConstantR1( + &builder, {-2, 25, 0, static_cast(-0.0), -123, inf(), -inf()}); + Sign(arg); ComputeAndCompareR1(&builder, {-1, 1, 0, 0, -1, 1, -1}, {}); } @@ -70,10 +70,10 @@ class UnaryOpTest : public ClientLibraryTestBase { template void SignAbsTestHelper() { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1({-2, 25, 0, -123}); - auto sign = builder.Sign(arg); - auto abs = builder.Abs(arg); - builder.Sub(builder.Mul(sign, abs), arg); + auto arg = ConstantR1(&builder, {-2, 25, 0, -123}); + auto sign = Sign(arg); + auto abs = Abs(arg); + Sub(Mul(sign, abs), arg); ComputeAndCompareR1(&builder, {0, 0, 0, 0}, {}); } @@ -92,13 +92,13 @@ int64 UnaryOpTest::inf() { template <> void UnaryOpTest::AbsTestHelper() { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1({{-2, 0}, - {0, 25}, - {0, 0}, - {-0.3f, 0.4f}, - {0, inf()}, - {-inf(), 0}}); - auto abs = builder.Abs(arg); + auto arg = ConstantR1(&builder, {{-2, 0}, + {0, 25}, + {0, 0}, + {-0.3f, 0.4f}, + {0, inf()}, + {-inf(), 0}}); + Abs(arg); std::unique_ptr expected = Literal::CreateR1({2, 25, 0, 0.5, inf(), inf()}); @@ -108,9 +108,10 @@ void UnaryOpTest::AbsTestHelper() { template <> void UnaryOpTest::SignTestHelper() { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1( + auto arg = ConstantR1( + &builder, {{-2, 0}, {0, 25}, {0, 0}, {static_cast(-0.0), 0}, {-1, 1}}); - auto sign = builder.Sign(arg); + Sign(arg); std::unique_ptr expected = Literal::CreateR1( {{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}}); @@ -121,10 +122,10 @@ template <> void UnaryOpTest::SignAbsTestHelper() { XlaBuilder builder(TestName()); auto arg = - builder.ConstantR1({{-2, 0}, {0, 25}, {0, 0}, {-0.4, 0.3}}); - auto sign = builder.Sign(arg); - auto abs = builder.Abs(arg); - builder.Sub(builder.Mul(sign, builder.ConvertElementType(abs, C64)), arg); + ConstantR1(&builder, {{-2, 0}, {0, 25}, {0, 0}, {-0.4, 0.3}}); + auto sign = Sign(arg); + auto abs = Abs(arg); + Sub(Mul(sign, ConvertElementType(abs, C64)), arg); std::unique_ptr expected = Literal::CreateR1({0, 0, 0, 0}); @@ -145,34 +146,31 @@ XLA_TEST_F(UnaryOpTest, AbsTestR1) { XLA_TEST_F(UnaryOpTest, AbsTestR0) { XlaBuilder builder(TestName()); - auto argi = builder.ConstantR0(-5); - auto absi = builder.Abs(argi); - auto argf = builder.ConstantR0(-3.0f); - auto absf = builder.Abs(argf); - auto argf0 = builder.ConstantR0(-0.0f); - auto absf0 = builder.Abs(argf0); - auto argc = builder.ConstantR0({-0.3f, 0.4f}); - auto absc = builder.Abs(argc); - builder.Add(builder.Add(absc, absf0), - builder.Add(absf, builder.ConvertElementType(absi, F32))); + auto argi = ConstantR0(&builder, -5); + auto absi = Abs(argi); + auto argf = ConstantR0(&builder, -3.0f); + auto absf = Abs(argf); + auto argf0 = ConstantR0(&builder, -0.0f); + auto absf0 = Abs(argf0); + auto argc = ConstantR0(&builder, {-0.3f, 0.4f}); + auto absc = Abs(argc); + Add(Add(absc, absf0), Add(absf, ConvertElementType(absi, F32))); ComputeAndCompareR0(&builder, 8.5f, {}); } XLA_TEST_F(UnaryOpTest, SignTestR0) { XlaBuilder builder(TestName()); - auto argi = builder.ConstantR0(-5); - auto sgni = builder.Sign(argi); // -1 - auto argf = builder.ConstantR0(-4.0f); - auto sgnf = builder.Sign(argf); // -1 - auto argf0 = builder.ConstantR0(-0.0f); - auto sgnf0 = builder.Sign(argf0); // 0 - auto argc = builder.ConstantR0({-.3, .4}); - auto sgnc = builder.Sign(argc); // (-.6, .8) - builder.Add(sgnc, builder.ConvertElementType( - builder.Add(builder.Add(sgnf0, sgnf), - builder.ConvertElementType(sgni, F32)), - C64)); + auto argi = ConstantR0(&builder, -5); + auto sgni = Sign(argi); // -1 + auto argf = ConstantR0(&builder, -4.0f); + auto sgnf = Sign(argf); // -1 + auto argf0 = ConstantR0(&builder, -0.0f); + auto sgnf0 = Sign(argf0); // 0 + auto argc = ConstantR0(&builder, {-.3, .4}); + auto sgnc = Sign(argc); // (-.6, .8) + Add(sgnc, ConvertElementType( + Add(Add(sgnf0, sgnf), ConvertElementType(sgni, F32)), C64)); std::unique_ptr expected = Literal::CreateR0({-2.6f, 0.8f}); @@ -194,9 +192,9 @@ XLA_TEST_F(UnaryOpTest, SignAbsTestR1) { XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1( - {2, 25, 0, 123, std::numeric_limits::max()}); - auto abs = builder.Abs(arg); + auto arg = ConstantR1( + &builder, {2, 25, 0, 123, std::numeric_limits::max()}); + Abs(arg); ComputeAndCompareR1( &builder, {2, 25, 0, 123, std::numeric_limits::max()}, {}); @@ -204,37 +202,37 @@ XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) { XLA_TEST_F(UnaryOpTest, UnsignedSignTestR1) { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1( - {2, 25, 0, 123, std::numeric_limits::max()}); - auto sign = builder.Sign(arg); + auto arg = ConstantR1( + &builder, {2, 25, 0, 123, std::numeric_limits::max()}); + Sign(arg); ComputeAndCompareR1(&builder, {1, 1, 0, 1, 1}, {}); } XLA_TEST_F(UnaryOpTest, SignAbsTestR2) { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR2({{1.0, -2.0}, {-3.0, 4.0}}); - auto sign = builder.Sign(arg); - auto abs = builder.Abs(arg); - builder.Sub(builder.Mul(sign, abs), arg); + auto arg = ConstantR2(&builder, {{1.0, -2.0}, {-3.0, 4.0}}); + auto sign = Sign(arg); + auto abs = Abs(arg); + Sub(Mul(sign, abs), arg); ComputeAndCompareR2(&builder, {{0, 0}, {0, 0}}, {}); } XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToS32) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 1}); - auto rhs = builder.ConstantR1({1, 1}); - builder.ConvertElementType(builder.Eq(lhs, rhs), S32); + auto lhs = ConstantR1(&builder, {0, 1}); + auto rhs = ConstantR1(&builder, {1, 1}); + ConvertElementType(Eq(lhs, rhs), S32); ComputeAndCompareR1(&builder, {0, 1}, {}); } XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToF32) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 1}); - auto rhs = builder.ConstantR1({1, 1}); - builder.ConvertElementType(builder.Eq(lhs, rhs), F32); + auto lhs = ConstantR1(&builder, {0, 1}); + auto rhs = ConstantR1(&builder, {1, 1}); + ConvertElementType(Eq(lhs, rhs), F32); ComputeAndCompareR1(&builder, {0.0, 1.0}, {}); } diff --git a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc index 82d301983fc7885ef5c1c1ed05b74fc017bb7727..ea3aba6df1d3fbd492a23b280309322b8524c0bf 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc @@ -46,7 +46,7 @@ class VecOpsReduceTest : public ClientLibraryTestBase { {{1.0, 2.0, 3.0}, // } plane 2 in dim 0 {4.0, 5.0, 6.0}}}); // clang-format on - return builder_.ConstantR3FromArray3D(x3d); + return ConstantR3FromArray3D(&builder_, x3d); } XlaBuilder builder_; @@ -56,11 +56,10 @@ class VecOpsReduceTest : public ClientLibraryTestBase { TEST_F(VecOpsReduceTest, AddReduceR1F32) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); - auto x = builder_.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0}); + auto x = ConstantR1( + &builder_, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0}); ComputeAndCompareR0(&builder_, -4.2f, {}, errspec_); } @@ -71,10 +70,9 @@ TEST_F(VecOpsReduceTest, AddReduceBigR1F32) { std::vector input(3000); std::iota(input.begin(), input.end(), 100.0f); - auto x = builder_.ConstantR1(input); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0}); + auto x = ConstantR1(&builder_, input); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0}); float expected = std::accumulate(input.begin(), input.end(), 0.0f); ComputeAndCompareR0(&builder_, expected, {}, errspec_); @@ -83,11 +81,10 @@ TEST_F(VecOpsReduceTest, AddReduceBigR1F32) { TEST_F(VecOpsReduceTest, MaxReduceR1F32) { auto max_reducer = CreateScalarMax(); - auto x = builder_.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto max_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), max_reducer, - /*dimensions_to_reduce=*/{0}); + auto x = ConstantR1( + &builder_, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Reduce(x, ConstantR0(&builder_, 0.0f), max_reducer, + /*dimensions_to_reduce=*/{0}); ComputeAndCompareR0(&builder_, 2.6f, {}, errspec_); } @@ -95,11 +92,10 @@ TEST_F(VecOpsReduceTest, MaxReduceR1F32) { TEST_F(VecOpsReduceTest, MaxReduceR1F32WithNontrivialInit) { auto max_reducer = CreateScalarMax(); - auto x = builder_.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto max_reduce = - builder_.Reduce(x, builder_.ConstantR0(4.0f), max_reducer, - /*dimensions_to_reduce=*/{0}); + auto x = ConstantR1( + &builder_, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Reduce(x, ConstantR0(&builder_, 4.0f), max_reducer, + /*dimensions_to_reduce=*/{0}); ComputeAndCompareR0(&builder_, 4.0f, {}, errspec_); } @@ -108,15 +104,14 @@ TEST_F(VecOpsReduceTest, AddReduceR2F32Dim1) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); // clang-format off - auto x = builder_.ConstantR2({ + auto x = ConstantR2(&builder_, { {1.0, 2.0, 3.0}, // | dim 0 {4.0, 5.0, 6.0}}); // | // ------ dim 1 ---------- // clang-format on - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{1}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{1}); ComputeAndCompareR1(&builder_, {6.0, 15.0}, {}, errspec_); } @@ -125,13 +120,12 @@ TEST_F(VecOpsReduceTest, AddReduceR2F32Dim0) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); // clang-format off - auto x = builder_.ConstantR2({ + auto x = ConstantR2(&builder_, { {1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}); // clang-format on - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0}); ComputeAndCompareR1(&builder_, {5.0, 7.0, 9.0}, {}, errspec_); } @@ -139,9 +133,8 @@ TEST_F(VecOpsReduceTest, AddReduceR2F32Dim0) { TEST_F(VecOpsReduceTest, AddReduceR3F32Dim2) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{2}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{2}); Array2D expected_array({{6.0f, 15.0f}, {6.0f, 15.0f}, {6.0f, 15.0f}}); @@ -151,9 +144,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dim2) { TEST_F(VecOpsReduceTest, AddReduceR3F32Dim1) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{1}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{1}); Array2D expected_array( {{5.0f, 7.0f, 9.0f}, {5.0f, 7.0f, 9.0f}, {5.0f, 7.0f, 9.0f}}); @@ -164,9 +156,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dim1) { TEST_F(VecOpsReduceTest, AddReduceR3F32Dim0) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0}); Array2D expected_array({{3.0f, 6.0f, 9.0f}, {12.0f, 15.0f, 18.0f}}); @@ -176,9 +167,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dim0) { TEST_F(VecOpsReduceTest, AddReduceR3F32Dims1and2) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{1, 2}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{1, 2}); ComputeAndCompareR1(&builder_, {21.0, 21.0, 21.0}, {}, errspec_); } @@ -186,9 +176,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dims1and2) { XLA_TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and2) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0, 2}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0, 2}); ComputeAndCompareR1(&builder_, {18.0, 45.0}, {}, errspec_); } @@ -196,9 +185,8 @@ XLA_TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and2) { TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and1) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0, 1}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0, 1}); ComputeAndCompareR1(&builder_, {15.0, 21.0, 27.0}, {}, errspec_); } @@ -206,9 +194,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and1) { TEST_F(VecOpsReduceTest, AddReduceR3F32AllDims) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0, 1, 2}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0, 1, 2}); ComputeAndCompareR0(&builder_, 63.0, {}, errspec_); } diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index 5cce7a2bf82c1a8403536a91e67910f949ef185a..c11df7cdf5a22568e80ce6e00fdbd862e6dcae9b 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -50,9 +50,9 @@ class VecOpsSimpleTest : public ClientLibraryTestBase { XLA_TEST_F(VecOpsSimpleTest, ExpTenValues) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto exp = builder.Exp(x); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Exp(x); std::vector expected = {8.1662, 7.4274e-02, 13.4637, 1.8316e-02, 8.1662, 9.9742, 6.7379e-03, 4.0657e-01, @@ -69,8 +69,8 @@ XLA_TEST_F(VecOpsSimpleTest, ExpManyValues) { for (int i = 0; i < count; ++i) { exponents.push_back(i / static_cast(count)); } - auto x = builder.ConstantR1(exponents); - auto exp = builder.Exp(x); + auto x = ConstantR1(&builder, exponents); + Exp(x); std::vector expected; expected.reserve(exponents.size()); @@ -98,8 +98,8 @@ XLA_TEST_F(VecOpsSimpleTest, ExpIn4D) { Array4D expected(2, 2, 2, 2, expected_vector); - auto x = builder.ConstantR4FromArray4D(exponents); - auto exp = builder.Exp(x); + auto x = ConstantR4FromArray4D(&builder, exponents); + Exp(x); ComputeAndCompareR4(&builder, expected, {}, ErrorSpec(/*aabs=*/1e-2, /*arel=*/1e-3)); @@ -107,9 +107,9 @@ XLA_TEST_F(VecOpsSimpleTest, ExpIn4D) { XLA_TEST_F(VecOpsSimpleTest, NegateTenFloatValues) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - builder.Neg(x); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Neg(x); std::vector expected = {-2.1, 2.6, -2.6, 4.0, -2.1, -2.3, 5.0, 0.9, 2.4, -1.6}; @@ -118,8 +118,8 @@ XLA_TEST_F(VecOpsSimpleTest, NegateTenFloatValues) { XLA_TEST_F(VecOpsSimpleTest, NegateTenInt32Values) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({2, -2, 12, -4, 5, 20, -15, 0, -2, 1}); - builder.Neg(x); + auto x = ConstantR1(&builder, {2, -2, 12, -4, 5, 20, -15, 0, -2, 1}); + Neg(x); std::vector expected = {-2, 2, -12, 4, -5, -20, 15, 0, 2, -1}; ComputeAndCompareR1(&builder, expected, {}); @@ -127,9 +127,9 @@ XLA_TEST_F(VecOpsSimpleTest, NegateTenInt32Values) { XLA_TEST_F(VecOpsSimpleTest, NegateUint32Values) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {0, 1, 42, static_cast(-1), static_cast(-12)}); - builder.Neg(x); + auto x = ConstantR1( + &builder, {0, 1, 42, static_cast(-1), static_cast(-12)}); + Neg(x); std::vector expected = {0, static_cast(-1), static_cast(-42), 1, 12}; ComputeAndCompareR1(&builder, expected, {}); @@ -137,9 +137,9 @@ XLA_TEST_F(VecOpsSimpleTest, NegateUint32Values) { XLA_TEST_F(VecOpsSimpleTest, SquareTenValues) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - builder.SquareF32(x); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + SquareF32(x); std::vector expected = {4.41, 6.76, 6.76, 16., 4.41, 5.29, 25., 0.81, 5.76, 2.56}; @@ -148,9 +148,9 @@ XLA_TEST_F(VecOpsSimpleTest, SquareTenValues) { XLA_TEST_F(VecOpsSimpleTest, ReciprocalTenValues) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - builder.ReciprocalF32(x); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + ReciprocalF32(x); std::vector expected = { 0.47619048, -0.38461538, 0.38461538, -0.25, 0.47619048, @@ -160,16 +160,16 @@ XLA_TEST_F(VecOpsSimpleTest, ReciprocalTenValues) { XLA_TEST_F(VecOpsSimpleTest, SqrtZeroes) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({0.0, -0.0}); - auto exp = builder.SqrtF32(x); + auto x = ConstantR1(&builder, {0.0, -0.0}); + SqrtF32(x); ComputeAndCompareR1(&builder, {0, 0}, {}, error_spec_); } XLA_TEST_F(VecOpsSimpleTest, SqrtSixValues) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({16.0, 1.0, 1024.0, 0.16, 0.2, 12345}); - auto exp = builder.SqrtF32(x); + auto x = ConstantR1(&builder, {16.0, 1.0, 1024.0, 0.16, 0.2, 12345}); + SqrtF32(x); std::vector expected = {4, 1, 32, 0.4, 0.4472, 111.1080}; ComputeAndCompareR1(&builder, expected, {}, error_spec_); @@ -177,9 +177,9 @@ XLA_TEST_F(VecOpsSimpleTest, SqrtSixValues) { XLA_TEST_F(VecOpsSimpleTest, InvSqrtSevenValues) { XlaBuilder builder(TestName()); - auto x = - builder.ConstantR1({16.0, 1.0, 1024.0, 0.16, 0.2, 12345, 1.2345}); - auto exp = builder.Pow(x, builder.ConstantR0(-.5f)); + auto x = ConstantR1(&builder, + {16.0, 1.0, 1024.0, 0.16, 0.2, 12345, 1.2345}); + Pow(x, ConstantR0(&builder, -.5f)); std::vector expected = {.25, 1, .03125, 2.5, 2.23607, .009000, .900025}; @@ -191,11 +191,11 @@ XLA_TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) { XlaBuilder builder(TestName()); auto add = CreateScalarAddComputation(F32, &builder); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto y = builder.ConstantR1( - {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); - auto max = builder.Map({x, y}, add, {0}); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + auto y = ConstantR1( + &builder, {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); + Map(&builder, {x, y}, add, {0}); std::vector expected = {1.7, -3.2, -0.4, -3.8, 5.9, 0.1, -6.8, 4., -1., 2.2}; @@ -204,11 +204,11 @@ XLA_TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) { XLA_TEST_F(VecOpsSimpleTest, MaxTenValues) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto y = builder.ConstantR1( - {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); - auto max = builder.Max(x, y); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + auto y = ConstantR1( + &builder, {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); + Max(x, y); std::vector expected = {2.1, -0.6, 2.6, 0.2, 3.8, 2.3, -1.8, 4.9, 1.4, 1.6}; @@ -227,7 +227,7 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) { {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&builder, /*data_handle=*/&v2); - auto max = builder.Max(v1, v2); + Max(v1, v2); ComputeAndCompareR1(&builder, {41.0f, 22.0f, 23.0f, 84.0f}, {param0_data.get(), param1_data.get()}, error_spec_); @@ -267,7 +267,7 @@ XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) { CreateR1Parameter(v2vec, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&builder, /*data_handle=*/&v2); - auto max = builder.Max(v1, v2); + Max(v1, v2); ComputeAndCompareR1(&builder, expected_vec, {param0_data.get(), param1_data.get()}, error_spec_); @@ -275,10 +275,10 @@ XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) { XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto y = builder.ConstantR0(0); - auto max = builder.Max(x, y); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + auto y = ConstantR0(&builder, 0); + Max(x, y); std::vector expected = {2.1, 0.0, 2.6, 0.0, 2.1, 2.3, 0.0, 0.0, 0.0, 1.6}; @@ -287,11 +287,11 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) { XLA_TEST_F(VecOpsSimpleTest, MinTenValues) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto y = builder.ConstantR1( - {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); - auto min = builder.Min(x, y); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + auto y = ConstantR1( + &builder, {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); + Min(x, y); std::vector expected = {-0.4, -2.6, -3.0, -4.0, 2.1, -2.2, -5.0, -0.9, -2.4, 0.6}; @@ -300,11 +300,11 @@ XLA_TEST_F(VecOpsSimpleTest, MinTenValues) { XLA_TEST_F(VecOpsSimpleTest, MinMaxTenValues) { XlaBuilder builder(TestName()); - auto zero = builder.ConstantR0(0); - auto one = builder.ConstantR0(1); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); - auto clamp = builder.Min(builder.Max(x, zero), one); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); + Min(Max(x, zero), one); std::vector expected = {1.0, 0.0, 1.0, 0.3, 1.0, 0.9, 0.0, 0.1, 0.0, 0.6}; @@ -313,11 +313,11 @@ XLA_TEST_F(VecOpsSimpleTest, MinMaxTenValues) { XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) { XlaBuilder builder(TestName()); - auto zero = builder.ConstantR0(0); - auto one = builder.ConstantR0(1); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); - auto clamp = builder.Clamp(zero, x, one); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); + Clamp(zero, x, one); std::vector expected = {1.0, 0.0, 1.0, 0.3, 1.0, 0.9, 0.0, 0.1, 0.0, 0.6}; @@ -326,10 +326,10 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) { XLA_TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) { XlaBuilder builder(TestName()); - auto zero = builder.ConstantR1({0.0f, 0.0f}); - auto one = builder.ConstantR1({1.0f, 1.0f}); - auto x = builder.ConstantR1({2.1, -2.6}); - auto clamp = builder.Clamp(zero, x, one); + auto zero = ConstantR1(&builder, {0.0f, 0.0f}); + auto one = ConstantR1(&builder, {1.0f, 1.0f}); + auto x = ConstantR1(&builder, {2.1, -2.6}); + Clamp(zero, x, one); std::vector expected = {1.0, 0.0}; ComputeAndCompareR1(&builder, expected, {}); @@ -337,11 +337,11 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) { XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { XlaBuilder builder(TestName()); - auto one = builder.ConstantR0(1); - auto two = builder.ConstantR0(2); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); - auto clamp = builder.Clamp(one, x, two); + auto one = ConstantR0(&builder, 1); + auto two = ConstantR0(&builder, 2); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); + Clamp(one, x, two); std::vector expected = {2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0}; @@ -350,10 +350,10 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { XLA_TEST_F(VecOpsSimpleTest, ClampValuesConstantS64) { XlaBuilder builder(TestName()); - auto zero = builder.ConstantR0(0); - auto one = builder.ConstantR0(10); - auto x = builder.ConstantR1({-3, 3, 9, 13}); - auto clamp = builder.Clamp(zero, x, one); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 10); + auto x = ConstantR1(&builder, {-3, 3, 9, 13}); + Clamp(zero, x, one); std::vector expected = {0, 3, 9, 10}; ComputeAndCompareR1(&builder, expected, {}); @@ -365,9 +365,9 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { // add_half(x) = x + 0.5 XlaBuilder builder("add_half"); auto x_value = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x_value"); - auto half = builder.ConstantR0(0.5); - builder.Add(x_value, half); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x_value"); + auto half = ConstantR0(&builder, 0.5); + Add(x_value, half); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); add_half = computation_status.ConsumeValueOrDie(); @@ -378,9 +378,9 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { // clamp(y) = clamp<0,5>(y) XlaBuilder builder("clamp"); auto y_value = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y_value"); - auto zero = builder.ConstantR0(0.0); - auto clamped = builder.Clamp(zero, y_value, builder.ConstantR0(5)); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "y_value"); + auto zero = ConstantR0(&builder, 0.0); + Clamp(zero, y_value, ConstantR0(&builder, 5)); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); clamp = computation_status.ConsumeValueOrDie(); @@ -391,13 +391,13 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { // mult_relu_add(z) = clamp(add_half(2 * max(z, 0))) XlaBuilder builder("mult_relu_add"); auto z_value = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value"); - auto zero = builder.ConstantR0(0.0); - auto two = builder.ConstantR0(2.0); - auto max = builder.Max(z_value, zero); - auto mult = builder.Mul(two, max); - auto inner = builder.Map({mult}, add_half, {}); - builder.Map({inner}, clamp, {}); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "z_value"); + auto zero = ConstantR0(&builder, 0.0); + auto two = ConstantR0(&builder, 2.0); + auto max = Max(z_value, zero); + auto mult = Mul(two, max); + auto inner = Map(&builder, {mult}, add_half, {}); + Map(&builder, {inner}, clamp, {}); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); mult_relu_add = computation_status.ConsumeValueOrDie(); @@ -405,9 +405,9 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { XlaBuilder builder("map10"); { - auto x = builder.ConstantR1( - {2.1, -21.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto activations = builder.Map({x}, mult_relu_add, {0}); + auto x = ConstantR1( + &builder, {2.1, -21.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Map(&builder, {x}, mult_relu_add, {0}); } std::vector expected = {4.7, 0.5, 5.0, 0.5, 4.7, @@ -417,9 +417,9 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { XLA_TEST_F(VecOpsSimpleTest, RemainderTenValuesS32) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({-5, -4, -3, -2, -1, 0, 1, 2, 3, 4}); - auto y = builder.ConstantR0(3); - builder.Rem(x, y); + auto x = ConstantR1(&builder, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4}); + auto y = ConstantR0(&builder, 3); + Rem(x, y); std::vector expected = {-2, -1, 0, -2, -1, 0, 1, 2, 0, 1}; ComputeAndCompareR1(&builder, expected, {}); @@ -427,9 +427,9 @@ XLA_TEST_F(VecOpsSimpleTest, RemainderTenValuesS32) { XLA_TEST_F(VecOpsSimpleTest, VectorPredicateEqual) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({false, true}); - auto y = builder.ConstantR1({true, false}); - builder.Eq(x, y); + auto x = ConstantR1(&builder, {false, true}); + auto y = ConstantR1(&builder, {true, false}); + Eq(x, y); std::array expected = {{false, false}}; ComputeAndCompareR1(&builder, expected, {}); @@ -437,9 +437,9 @@ XLA_TEST_F(VecOpsSimpleTest, VectorPredicateEqual) { XLA_TEST_F(VecOpsSimpleTest, VectorPredicateNotEqual) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({false, true}); - auto y = builder.ConstantR1({true, false}); - builder.Ne(x, y); + auto x = ConstantR1(&builder, {false, true}); + auto y = ConstantR1(&builder, {true, false}); + Ne(x, y); std::array expected = {{true, true}}; ComputeAndCompareR1(&builder, expected, {}); diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index c463f3eac55e5b8ab32dc52d5a38e7840241bc58..bbd67cd8d7c433550deefc38ce28b2b732d354aa 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -55,8 +55,8 @@ TEST_F(WhileTest, WhileWithScalarS32Result) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.Gt(builder.ConstantR0(5), prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + Gt(ConstantR0(&builder, 5), prev); condition = builder.Build().ConsumeValueOrDie(); } @@ -64,16 +64,16 @@ TEST_F(WhileTest, WhileWithScalarS32Result) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto input = builder.ConstantR0(1); - builder.Add(input, prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto input = ConstantR0(&builder, 1); + Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder(TestName()); - auto init = builder.ConstantR0(0); - builder.While(condition, body, init); + auto init = ConstantR0(&builder, 0); + While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -91,8 +91,8 @@ TEST_F(WhileTest, WhileWithScalarS64Result) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.Gt(builder.ConstantR0(5), prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + Gt(ConstantR0(&builder, 5), prev); condition = builder.Build().ConsumeValueOrDie(); } @@ -100,16 +100,16 @@ TEST_F(WhileTest, WhileWithScalarS64Result) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto input = builder.ConstantR0(1); - builder.Add(input, prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto input = ConstantR0(&builder, 1); + Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder(TestName()); - auto init = builder.ConstantR0(0); - builder.While(condition, body, init); + auto init = ConstantR0(&builder, 0); + While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -122,8 +122,8 @@ TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.Gt(builder.ConstantR0(5), prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + Gt(ConstantR0(&builder, 5), prev); condition = builder.Build().ConsumeValueOrDie(); } @@ -131,18 +131,18 @@ TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto input = builder.ConstantR0(1); - builder.Add(input, prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto input = ConstantR0(&builder, 1); + Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder(TestName()); - auto init = builder.Reduce(builder.ConstantR1(2, 1), - builder.ConstantR0(0), - CreateScalarAddComputation(S32, &builder), {0}); - builder.While(condition, body, init); + auto init = + Reduce(ConstantR1(&builder, 2, 1), ConstantR0(&builder, 0), + CreateScalarAddComputation(S32, &builder), {0}); + While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -154,8 +154,8 @@ TEST_F(WhileTest, WhileWithPredicateResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.Ne(builder.ConstantR0(true), prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + Ne(ConstantR0(&builder, true), prev); condition = builder.Build().ConsumeValueOrDie(); } @@ -163,16 +163,16 @@ TEST_F(WhileTest, WhileWithPredicateResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.Or(prev, builder.ConstantR0(true)); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + Or(prev, ConstantR0(&builder, true)); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder(TestName()); - auto init = builder.Ne(builder.ConstantR0(false), - builder.ConstantR0(true)); - builder.While(condition, body, init); + auto init = + Ne(ConstantR0(&builder, false), ConstantR0(&builder, true)); + While(condition, body, init); ComputeAndCompareR0(&builder, true, {}); } @@ -184,17 +184,16 @@ TEST_F(WhileTest, WhileWithPredicateResult) { // while (result.sum() < 15.5f) { // result = result + vector(0); // } -// TODO(b/29185393): does not terminate on CPU. -TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { +TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithEmptyVectorResult)) { Shape result_shape = ShapeUtil::MakeShape(F32, {0}); // Create a computation for the reduction. XlaComputation add; { XlaBuilder builder("add"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Add(x, y); add = builder.Build().ConsumeValueOrDie(); } @@ -203,10 +202,10 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto sum = builder.Reduce(prev, builder.ConstantR0(0.0f), add, - /*dimensions_to_reduce=*/{0}); - builder.Gt(builder.ConstantR0(15.5f), sum); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto sum = Reduce(prev, ConstantR0(&builder, 0.0f), add, + /*dimensions_to_reduce=*/{0}); + Gt(ConstantR0(&builder, 15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } @@ -215,16 +214,16 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto input = builder.ConstantR1({}); - builder.Add(input, prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto input = ConstantR1(&builder, {}); + Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.ConstantR1({}); - auto result = builder.While(condition, body, init); + auto init = ConstantR1(&builder, {}); + auto result = While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -246,9 +245,9 @@ TEST_F(WhileTest, WhileWithVectorResult) { XlaComputation add; { XlaBuilder builder("add"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Add(x, y); add = builder.Build().ConsumeValueOrDie(); } @@ -257,10 +256,10 @@ TEST_F(WhileTest, WhileWithVectorResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto sum = builder.Reduce(prev, builder.ConstantR0(0.0f), add, - /*dimensions_to_reduce=*/{0}); - builder.Gt(builder.ConstantR0(15.5f), sum); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto sum = Reduce(prev, ConstantR0(&builder, 0.0f), add, + /*dimensions_to_reduce=*/{0}); + Gt(ConstantR0(&builder, 15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } @@ -269,16 +268,16 @@ TEST_F(WhileTest, WhileWithVectorResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto input = builder.ConstantR1(8, 0.125f); - builder.Add(input, prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto input = ConstantR1(&builder, 8, 0.125f); + Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.ConstantR1(8, 0.f); - auto result = builder.While(condition, body, init); + auto init = ConstantR1(&builder, 8, 0.f); + auto result = While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -306,9 +305,9 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { XlaComputation add; { XlaBuilder builder("add"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Add(x, y); add = builder.Build().ConsumeValueOrDie(); } @@ -317,10 +316,10 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto sum = builder.Reduce(prev, builder.ConstantR0(0.0f), add, - /*dimensions_to_reduce=*/{0}); - builder.Gt(builder.ConstantR0(15.5f), sum); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto sum = Reduce(prev, ConstantR0(&builder, 0.0f), add, + /*dimensions_to_reduce=*/{0}); + Gt(ConstantR0(&builder, 15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } @@ -329,20 +328,20 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto input = builder.ConstantR1(8, 0.125f); - builder.Add(input, prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto input = ConstantR1(&builder, 8, 0.125f); + Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.ConstantR1(8, 0.f); - auto result = builder.While(condition, body, init); + auto init = ConstantR1(&builder, 8, 0.f); + auto result = While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - builder.Tuple({result}); + Tuple(&builder, {result}); // Individual elements with increase by 1/8 each time through the loop, so // the sum will increase by 1.0. It will first be >15.5 when the elements @@ -366,9 +365,9 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Gt(builder.ConstantR0(N), iteration); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Gt(ConstantR0(&builder, N), iteration); condition = builder.Build().ConsumeValueOrDie(); } @@ -377,22 +376,23 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto w1 = builder.GetTupleElement(prev, 1); - auto w2 = builder.GetTupleElement(prev, 2); - auto w3 = builder.GetTupleElement(prev, 3); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), w3, w1, w2}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto w1 = GetTupleElement(prev, 1); + auto w2 = GetTupleElement(prev, 2); + auto w3 = GetTupleElement(prev, 3); + Tuple(&builder, + {Add(iteration, ConstantR0(&builder, 1)), w3, w1, w2}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR1(3, 1.f), - builder.ConstantR1(3, 2.f), builder.ConstantR1(3, 3.f)}); - auto result = builder.While(condition, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR1(&builder, 3, 1.f), + ConstantR1(&builder, 3, 2.f), + ConstantR1(&builder, 3, 3.f)}); + auto result = While(condition, body, init); VLOG(2) << "result = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -419,9 +419,9 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Gt(builder.ConstantR0(N), iteration); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Gt(ConstantR0(&builder, N), iteration); condition = builder.Build().ConsumeValueOrDie(); } @@ -430,26 +430,27 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto w1 = builder.GetTupleElement(prev, 1); - auto w2 = builder.GetTupleElement(prev, 2); - auto w3 = builder.GetTupleElement(prev, 3); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), w3, w1, w2}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto w1 = GetTupleElement(prev, 1); + auto w2 = GetTupleElement(prev, 2); + auto w3 = GetTupleElement(prev, 3); + Tuple(&builder, + {Add(iteration, ConstantR0(&builder, 1)), w3, w1, w2}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR1(3, 1.f), - builder.ConstantR1(3, 2.f), builder.ConstantR1(3, 3.f)}); - auto xla_while = builder.While(condition, body, init); - - auto add12 = builder.Add(builder.GetTupleElement(xla_while, 1), - builder.GetTupleElement(xla_while, 2)); - auto result = builder.Add(add12, builder.GetTupleElement(xla_while, 3)); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR1(&builder, 3, 1.f), + ConstantR1(&builder, 3, 2.f), + ConstantR1(&builder, 3, 3.f)}); + auto xla_while = While(condition, body, init); + + auto add12 = + Add(GetTupleElement(xla_while, 1), GetTupleElement(xla_while, 2)); + auto result = Add(add12, GetTupleElement(xla_while, 3)); VLOG(2) << "result = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -474,9 +475,9 @@ TEST_F(WhileTest, WhileWithTupleResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Gt(builder.ConstantR0(5), iteration); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Gt(ConstantR0(&builder, 5), iteration); condition = builder.Build().ConsumeValueOrDie(); } @@ -486,21 +487,21 @@ TEST_F(WhileTest, WhileWithTupleResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto weights = builder.GetTupleElement(prev, 1); - auto input = builder.ConstantR1(10, 1.f); - auto new_weights = builder.Add(weights, input); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto weights = GetTupleElement(prev, 1); + auto input = ConstantR1(&builder, 10, 1.f); + auto new_weights = Add(weights, input); + Tuple(&builder, + {Add(iteration, ConstantR0(&builder, 1)), new_weights}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); - auto result = builder.While(condition, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR1(&builder, 10, 0.f)}); + auto result = While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -524,9 +525,9 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Gt(builder.ConstantR0(5), iteration); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Gt(ConstantR0(&builder, 5), iteration); condition = builder.Build().ConsumeValueOrDie(); } @@ -535,21 +536,20 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto pred = builder.GetTupleElement(prev, 1); - auto new_pred = builder.Or(pred, builder.ConstantR0(true)); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), new_pred}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto pred = GetTupleElement(prev, 1); + auto new_pred = Or(pred, ConstantR0(&builder, true)); + Tuple(&builder, {Add(iteration, ConstantR0(&builder, 1)), new_pred}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple({builder.ConstantR0(0), - builder.Ne(builder.ConstantR0(false), - builder.ConstantR0(true))}); - auto result = builder.While(condition, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + Ne(ConstantR0(&builder, false), + ConstantR0(&builder, true))}); + auto result = While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -571,9 +571,9 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Gt(builder.ConstantR0(5), iteration); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Gt(ConstantR0(&builder, 5), iteration); condition = builder.Build().ConsumeValueOrDie(); } @@ -583,18 +583,18 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Tuple({builder.Add(iteration, builder.ConstantR0(1)), - builder.ConstantR0(7)}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Tuple(&builder, {Add(iteration, ConstantR0(&builder, 1)), + ConstantR0(&builder, 7)}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR0(7)}); - auto result = builder.While(condition, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR0(&builder, 7)}); + auto result = While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -632,9 +632,9 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { const int c1 = 5; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Lt(iteration, builder.ConstantR0(c1)); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Lt(iteration, ConstantR0(&builder, c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } @@ -642,9 +642,9 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { const int c2 = 7; { XlaBuilder builder("condition2"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Lt(iteration, builder.ConstantR0(c2)); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Lt(iteration, ConstantR0(&builder, c2)); TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build()); } @@ -654,43 +654,43 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto weights = builder.GetTupleElement(prev, 1); - auto input = builder.ConstantR1(10, 1.f); - auto new_weights = builder.Add(weights, input); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto weights = GetTupleElement(prev, 1); + auto input = ConstantR1(&builder, 10, 1.f); + auto new_weights = Add(weights, input); + Tuple(&builder, + {Add(iteration, ConstantR0(&builder, 1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } XlaComputation body2; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto weights = builder.GetTupleElement(prev, 1); - auto input = builder.ConstantR1(10, 1.f); - auto new_weights = builder.Add(weights, input); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto weights = GetTupleElement(prev, 1); + auto input = ConstantR1(&builder, 10, 1.f); + auto new_weights = Add(weights, input); + Tuple(&builder, + {Add(iteration, ConstantR0(&builder, 1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body2, builder.Build()); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); - auto while1 = builder.While(condition, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR1(&builder, 10, 0.f)}); + auto while1 = While(condition, body, init); - auto while2 = builder.While(condition2, body2, while1); + auto while2 = While(condition2, body2, while1); - auto while_result1 = builder.GetTupleElement(while1, 1); - auto while_result2 = builder.GetTupleElement(while2, 1); + auto while_result1 = GetTupleElement(while1, 1); + auto while_result2 = GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( builder.GetShape(while_result2).ConsumeValueOrDie()); - auto result = builder.Add(while_result1, while_result2); + auto result = Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -711,9 +711,9 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { const int c1 = 5; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Lt(iteration, builder.ConstantR0(c1)); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Lt(iteration, ConstantR0(&builder, c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } @@ -721,9 +721,9 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { const int c2 = 7; { XlaBuilder builder("condition2"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Lt(iteration, builder.ConstantR0(c2)); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Lt(iteration, ConstantR0(&builder, c2)); TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build()); } @@ -733,30 +733,30 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto weights = builder.GetTupleElement(prev, 1); - auto input = builder.ConstantR1(10, 1.f); - auto new_weights = builder.Add(weights, input); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto weights = GetTupleElement(prev, 1); + auto input = ConstantR1(&builder, 10, 1.f); + auto new_weights = Add(weights, input); + Tuple(&builder, + {Add(iteration, ConstantR0(&builder, 1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); - auto while1 = builder.While(condition, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR1(&builder, 10, 0.f)}); + auto while1 = While(condition, body, init); - auto while2 = builder.While(condition2, body, while1); + auto while2 = While(condition2, body, while1); - auto while_result1 = builder.GetTupleElement(while1, 1); - auto while_result2 = builder.GetTupleElement(while2, 1); + auto while_result1 = GetTupleElement(while1, 1); + auto while_result2 = GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( builder.GetShape(while_result2).ConsumeValueOrDie()); - auto result = builder.Add(while_result1, while_result2); + auto result = Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -778,9 +778,9 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { const int c1 = 5; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Lt(iteration, builder.ConstantR0(c1)); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Lt(iteration, ConstantR0(&builder, c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } @@ -788,9 +788,9 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { const int c2 = 7; { XlaBuilder builder("condition2"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Lt(iteration, builder.ConstantR0(c2)); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Lt(iteration, ConstantR0(&builder, c2)); TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build()); } @@ -800,29 +800,29 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto weights = builder.GetTupleElement(prev, 1); - auto input = builder.ConstantR1(10, 1.f); - auto new_weights = builder.Add(weights, input); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto weights = GetTupleElement(prev, 1); + auto input = ConstantR1(&builder, 10, 1.f); + auto new_weights = Add(weights, input); + Tuple(&builder, + {Add(iteration, ConstantR0(&builder, 1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); - auto while1 = builder.While(condition, body, init); - auto while2 = builder.While(condition2, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR1(&builder, 10, 0.f)}); + auto while1 = While(condition, body, init); + auto while2 = While(condition2, body, init); - auto while_result1 = builder.GetTupleElement(while1, 1); - auto while_result2 = builder.GetTupleElement(while2, 1); + auto while_result1 = GetTupleElement(while1, 1); + auto while_result2 = GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( builder.GetShape(while_result2).ConsumeValueOrDie()); - auto result = builder.Add(while_result1, while_result2); + auto result = Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -844,9 +844,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Gt(builder.ConstantR0(5), iteration); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Gt(ConstantR0(&builder, 5), iteration); condition = builder.Build().ConsumeValueOrDie(); } @@ -856,29 +856,28 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); + auto prev = Parameter(&builder, 0, result_shape, "prev"); // TupleElement 0 - auto iteration = builder.GetTupleElement(prev, 0); - auto out0 = builder.Add(iteration, builder.ConstantR0(1)); + auto iteration = GetTupleElement(prev, 0); + auto out0 = Add(iteration, ConstantR0(&builder, 1)); // TupleElement 1 - auto input = builder.GetTupleElement(prev, 1); + auto input = GetTupleElement(prev, 1); // Update. - auto update = builder.ConvertElementType(builder.Broadcast(out0, {2}), F32); + auto update = ConvertElementType(Broadcast(out0, {2}), F32); // Starts = iteration * 2; - auto starts = builder.Reshape( - builder.Mul(iteration, builder.ConstantR0(2)), {1}); + auto starts = Reshape(Mul(iteration, ConstantR0(&builder, 2)), {1}); // UpdateSlice. - auto out1 = builder.DynamicUpdateSlice(input, update, starts); + auto out1 = DynamicUpdateSlice(input, update, starts); - builder.Tuple({out0, out1}); + Tuple(&builder, {out0, out1}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); - auto result = builder.While(condition, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR1(&builder, 10, 0.f)}); + auto result = While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -913,10 +912,9 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { // Create a computation for the condition: repeat for count iterations. auto build_condition = [this, v6s32](int count) { XlaBuilder builder(TestName()); - auto prev = builder.Reshape( - builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}, {1}), {0}, - {}); - builder.Gt(builder.ConstantR0(count), prev); + auto prev = Reshape( + Slice(Parameter(&builder, 0, v6s32, "prev"), {0}, {1}, {1}), {0}, {}); + Gt(ConstantR0(&builder, count), prev); return builder.Build().ConsumeValueOrDie(); }; @@ -924,22 +922,22 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, v6s32, "prev"); - auto inc = builder.ConcatInDim( - {builder.ConstantR1({1}), - builder.RngUniform(builder.ConstantR0(0), - builder.ConstantR0(100), - ShapeUtil::MakeShape(S32, {5}))}, - 0); - builder.Add(inc, prev); + auto prev = Parameter(&builder, 0, v6s32, "prev"); + auto inc = ConcatInDim(&builder, + {ConstantR1(&builder, {1}), + RngUniform(ConstantR0(&builder, 0), + ConstantR0(&builder, 100), + ShapeUtil::MakeShape(S32, {5}))}, + 0); + Add(inc, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. auto while_loop = [this, &body, build_condition](int count) { XlaBuilder builder(TestName()); - auto init = builder.ConstantR1({0, 0, 0, 0, 0, 0}); - builder.While(build_condition(count), body, init); + auto init = ConstantR1(&builder, {0, 0, 0, 0, 0, 0}); + While(build_condition(count), body, init); return builder.Build(); }; @@ -958,26 +956,23 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); XlaBuilder outer("outer"); - auto p = outer.Parameter(0, element_shape, "param"); - auto t = outer.Tuple({p, outer.ConstantR1({1, 1})}); + auto p = Parameter(&outer, 0, element_shape, "param"); + auto t = Tuple(&outer, {p, ConstantR1(&outer, {1, 1})}); TF_ASSERT_OK_AND_ASSIGN(Shape tuple_shape, outer.GetShape(t)); XlaBuilder cond("cond"); - auto cond_t = cond.Parameter(0, tuple_shape, "t"); - TF_ASSERT_OK(Any(cond.Eq(cond.GetTupleElement(cond_t, 0), - cond.ConstantR1({42, 42})), - &cond) - .status()); + auto cond_t = Parameter(&cond, 0, tuple_shape, "t"); + Any(Eq(GetTupleElement(cond_t, 0), ConstantR1(&cond, {42, 42}))); XlaBuilder body("body"); - auto body_t = body.Parameter(0, tuple_shape, "t"); - auto e = body.GetTupleElement(body_t, 1); - body.Tuple({e, e}); + auto body_t = Parameter(&body, 0, tuple_shape, "t"); + auto e = GetTupleElement(body_t, 1); + Tuple(&body, {e, e}); TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); - outer.While(cond_computation, body_computation, t); + While(cond_computation, body_computation, t); auto expected_element = Literal::CreateR1({1, 1}); auto expected = @@ -993,20 +988,19 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); XlaBuilder outer("outer"); - auto p = outer.Parameter(0, element_shape, "param"); + auto p = Parameter(&outer, 0, element_shape, "param"); XlaBuilder cond("cond"); - auto cond_t = cond.Parameter(0, element_shape, "t"); - TF_ASSERT_OK( - Any(cond.Eq(cond_t, cond.ConstantR1({42, 42})), &cond).status()); + auto cond_t = Parameter(&cond, 0, element_shape, "t"); + Any(Eq(cond_t, ConstantR1(&cond, {42, 42}))); XlaBuilder body("body"); - auto body_t = body.Parameter(0, element_shape, "t"); - auto e = body.Broadcast(body.ConstantR0(1.0), {2}); + Parameter(&body, 0, element_shape, "t"); + Broadcast(ConstantR0(&body, 1.0), {2}); TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); - outer.While(cond_computation, body_computation, p); + While(cond_computation, body_computation, p); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, @@ -1019,21 +1013,20 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {}); XlaBuilder outer("outer"); - auto p = outer.Parameter(0, element_shape, "param"); + auto p = Parameter(&outer, 0, element_shape, "param"); XlaBuilder cond("cond"); - auto cond_t = cond.Parameter(0, element_shape, "t"); - cond.Eq(cond_t, cond.ConstantR0(42)); + auto cond_t = Parameter(&cond, 0, element_shape, "t"); + Eq(cond_t, ConstantR0(&cond, 42)); XlaBuilder body("body"); - auto body_t = body.Parameter(0, element_shape, "t"); - auto tuple = - body.Tuple({body_t, body.Add(body_t, body.ConstantR0(1))}); - auto e = body.GetTupleElement(tuple, 1); + auto body_t = Parameter(&body, 0, element_shape, "t"); + auto tuple = Tuple(&body, {body_t, Add(body_t, ConstantR0(&body, 1))}); + GetTupleElement(tuple, 1); TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); - outer.While(cond_computation, body_computation, p); + While(cond_computation, body_computation, p); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, @@ -1056,25 +1049,23 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) { XlaBuilder outer("outer"); auto p = - outer.Tuple({outer.ConstantR0(0), - outer.Parameter(0, ShapeUtil::MakeShape(S32, {}), "t")}); + Tuple(&outer, {ConstantR0(&outer, 0), + Parameter(&outer, 0, ShapeUtil::MakeShape(S32, {}), "t")}); XlaBuilder cond("cond"); - auto params = cond.Parameter(0, result_shape, "prev"); - auto cond_t = cond.Add(cond.GetTupleElement(params, 1), - cond.GetTupleElement(params, 0)); - cond.Lt(cond_t, cond.ConstantR0(30)); + auto params = Parameter(&cond, 0, result_shape, "prev"); + auto cond_t = Add(GetTupleElement(params, 1), GetTupleElement(params, 0)); + Lt(cond_t, ConstantR0(&cond, 30)); XlaBuilder body("body"); - auto body_t = body.Parameter(0, result_shape, "t"); + auto body_t = Parameter(&body, 0, result_shape, "t"); - auto tuple = body.Tuple( - {body.Add(body.GetTupleElement(body_t, 0), body.ConstantR0(1)), - body.Add(body.GetTupleElement(body_t, 1), body.ConstantR0(1))}); + Tuple(&body, {Add(GetTupleElement(body_t, 0), ConstantR0(&body, 1)), + Add(GetTupleElement(body_t, 1), ConstantR0(&body, 1))}); TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); - outer.While(cond_computation, body_computation, p); + While(cond_computation, body_computation, p); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, @@ -1105,9 +1096,9 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { XlaComputation inner_condition; { XlaBuilder builder("inner_condition"); - auto params = builder.Parameter(0, inner_result_shape, "prev"); - auto i = builder.GetTupleElement(params, 0); - builder.Lt(i, builder.ConstantR0(7)); + auto params = Parameter(&builder, 0, inner_result_shape, "prev"); + auto i = GetTupleElement(params, 0); + Lt(i, ConstantR0(&builder, 7)); inner_condition = builder.Build().ConsumeValueOrDie(); } @@ -1116,8 +1107,8 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { XlaComputation outer_condition; { XlaBuilder builder("outer_condition"); - auto prev = builder.Parameter(0, outer_result_shape, "prev"); - builder.Lt(prev, builder.ConstantR0(30)); + auto prev = Parameter(&builder, 0, outer_result_shape, "prev"); + Lt(prev, ConstantR0(&builder, 30)); outer_condition = builder.Build().ConsumeValueOrDie(); } @@ -1126,12 +1117,12 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { XlaComputation inner_body; { XlaBuilder builder("inner_body"); - auto params = builder.Parameter(0, inner_result_shape, "prev"); - auto i = builder.GetTupleElement(params, 0); - auto result = builder.GetTupleElement(params, 1); - i = builder.Add(builder.ConstantR0(1), i); - result = builder.Add(builder.ConstantR0(2), result); - builder.Tuple({i, result}); + auto params = Parameter(&builder, 0, inner_result_shape, "prev"); + auto i = GetTupleElement(params, 0); + auto result = GetTupleElement(params, 1); + i = Add(ConstantR0(&builder, 1), i); + result = Add(ConstantR0(&builder, 2), result); + Tuple(&builder, {i, result}); inner_body = builder.Build().ConsumeValueOrDie(); } @@ -1139,17 +1130,17 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { XlaComputation outer_body; { XlaBuilder builder("outer_body"); - auto prev = builder.Parameter(0, outer_result_shape, "prev"); - auto init = builder.Tuple({builder.ConstantR0(0), prev}); - auto result = builder.While(inner_condition, inner_body, init); - builder.GetTupleElement(result, 1); + auto prev = Parameter(&builder, 0, outer_result_shape, "prev"); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), prev}); + auto result = While(inner_condition, inner_body, init); + GetTupleElement(result, 1); outer_body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder(TestName()); - auto init = builder.ConstantR0(0); - builder.While(outer_condition, outer_body, init); + auto init = ConstantR0(&builder, 0); + While(outer_condition, outer_body, init); ComputeAndCompareR0(&builder, 42, {}); } @@ -1167,8 +1158,8 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { XlaComputation condition_callee; { XlaBuilder builder("condition_callee"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.Tuple({builder.Gt(builder.ConstantR0(5), prev)}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + Tuple(&builder, {Gt(ConstantR0(&builder, 5), prev)}); condition_callee = builder.Build().ConsumeValueOrDie(); } @@ -1176,9 +1167,9 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto result = builder.Call(condition_callee, {prev}); - builder.GetTupleElement(result, 0); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto result = Call(&builder, condition_callee, {prev}); + GetTupleElement(result, 0); condition = builder.Build().ConsumeValueOrDie(); } @@ -1186,16 +1177,16 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto input = builder.ConstantR0(1); - builder.Add(input, prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto input = ConstantR0(&builder, 1); + Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder(TestName()); - auto init = builder.ConstantR0(0); - builder.While(condition, body, init); + auto init = ConstantR0(&builder, 0); + While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -1210,30 +1201,30 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { XlaComputation condition; { XlaBuilder builder("condition"); - auto state = builder.Parameter(0, while_shape, "state"); - builder.Gt(builder.ConstantR0(5), builder.GetTupleElement(state, 0)); + auto state = Parameter(&builder, 0, while_shape, "state"); + Gt(ConstantR0(&builder, 5), GetTupleElement(state, 0)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } XlaComputation body; { XlaBuilder builder("body"); - auto state = builder.Parameter(0, while_shape, "state"); - auto indvar = builder.GetTupleElement(state, 0); - auto input_0 = builder.GetTupleElement(state, 1); - auto input_1 = builder.GetTupleElement(state, 2); - auto output = builder.Tanh(builder.Dot(input_0, input_1)); - auto indvar_next = builder.Add(indvar, builder.ConstantR0(1)); - builder.Tuple({indvar_next, input_0, input_1, output}); + auto state = Parameter(&builder, 0, while_shape, "state"); + auto indvar = GetTupleElement(state, 0); + auto input_0 = GetTupleElement(state, 1); + auto input_1 = GetTupleElement(state, 2); + auto output = Tanh(Dot(input_0, input_1)); + auto indvar_next = Add(indvar, ConstantR0(&builder, 1)); + Tuple(&builder, {indvar_next, input_0, input_1, output}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } XlaBuilder builder(TestName()); - auto matrix_input = builder.Parameter(0, matrix_shape, "matrix"); - auto init = builder.Tuple( - {builder.ConstantR0(0), matrix_input, matrix_input, matrix_input}); - auto while_instruction = builder.While(condition, body, init); - builder.GetTupleElement(while_instruction, 3); + auto matrix_input = Parameter(&builder, 0, matrix_shape, "matrix"); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), matrix_input, + matrix_input, matrix_input}); + auto while_instruction = While(condition, body, init); + GetTupleElement(while_instruction, 3); TF_ASSERT_OK_AND_ASSIGN(auto param_value, client_->TransferToServer(*Literal::CreateR2( @@ -1264,9 +1255,9 @@ void BM_WhileLoop(int num_iters) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, loop_state_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Lt(iteration, builder.ConstantR0(loop_limit)); + auto prev = Parameter(&builder, 0, loop_state_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Lt(iteration, ConstantR0(&builder, loop_limit)); condition = builder.Build().ConsumeValueOrDie(); } @@ -1274,29 +1265,29 @@ void BM_WhileLoop(int num_iters) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, loop_state_shape, "prev"); + auto prev = Parameter(&builder, 0, loop_state_shape, "prev"); // TupleElement 0 - auto iteration = builder.GetTupleElement(prev, 0); - auto out0 = builder.Add(iteration, builder.ConstantR0(1)); + auto iteration = GetTupleElement(prev, 0); + auto out0 = Add(iteration, ConstantR0(&builder, 1)); // TupleElement 1 - auto input = builder.GetTupleElement(prev, 1); + auto input = GetTupleElement(prev, 1); // Update. - auto one = builder.ConstantR0(1.0); - auto update = builder.Broadcast(one, {1, 1024, 1024}); + auto one = ConstantR0(&builder, 1.0); + auto update = Broadcast(one, {1, 1024, 1024}); // Starts = iteration * 2; - auto starts = builder.ConstantR1({0, 0, 0}); + auto starts = ConstantR1(&builder, {0, 0, 0}); // UpdateSlice. - auto out1 = builder.DynamicUpdateSlice(input, update, starts); - builder.Tuple({out0, out1}); + auto out1 = DynamicUpdateSlice(input, update, starts); + Tuple(&builder, {out0, out1}); body = builder.Build().ConsumeValueOrDie(); } // Create a While instruction. XlaBuilder builder("while"); - auto zero = builder.ConstantR0(0.0); - auto input = builder.Broadcast(zero, {seq_len, 1024, 1024}); - auto init = builder.Tuple({builder.ConstantR0(0), input}); - builder.While(condition, body, init); + auto zero = ConstantR0(&builder, 0.0); + auto input = Broadcast(zero, {seq_len, 1024, 1024}); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), input}); + While(condition, body, init); auto computation = builder.Build().ConsumeValueOrDie(); std::unique_ptr executable = diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 3c9a01653c67203cbc962a3d3d967142f7a2102c..c0616809f9f060e3447e62d387535a5acffe1075 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -128,20 +128,23 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, se::StreamExecutor* executor = backend->default_stream_executor(); DeviceMemoryAllocator* allocator = backend->memory_allocator(); auto* transfer_manager = backend->transfer_manager(); + TF_ASSERT_OK_AND_ASSIGN( + Backend::StreamPtr stream_ptr, + backend->BorrowStream(backend->default_device_ordinal())); TF_ASSERT_OK_AND_ASSIGN( ScopedShapedBuffer lhs_arg, transfer_manager->AllocateScopedShapedBuffer( lhs_arg_shape, allocator, backend->default_device_ordinal())); TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( - executor, *Literal::CreateFromShape(lhs_arg_shape), lhs_arg)); + stream_ptr.get(), *Literal::CreateFromShape(lhs_arg_shape), lhs_arg)); TF_ASSERT_OK_AND_ASSIGN( ScopedShapedBuffer rhs_arg, transfer_manager->AllocateScopedShapedBuffer( rhs_arg_shape, allocator, backend->default_device_ordinal())); TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( - executor, *Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); + stream_ptr.get(), *Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr local_executable, @@ -153,9 +156,6 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, &executable->hlo_profile_printer_data(), &executable->hlo_profile_index_map()); - TF_ASSERT_OK_AND_ASSIGN( - Backend::StreamPtr stream_ptr, - backend->BorrowStream(backend->default_device_ordinal())); ExecutableRunOptions exec_run_options; exec_run_options.set_stream(stream_ptr.get()); exec_run_options.set_allocator(backend->memory_allocator()); @@ -168,6 +168,7 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, auto execution_result, executable->ExecuteOnStream(&run_options, {&lhs_arg, &rhs_arg}, &hlo_execution_profile)); + TF_ASSERT_OK(stream_ptr->BlockHostUntilDone()); (void)execution_result; *profile_output = @@ -187,9 +188,9 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { ClientLibrary::GetOrCreateLocalClient(platform)); XlaBuilder builder(TestName()); - auto result = builder.Tanh(builder.Add( - builder.Parameter(0, ShapeUtil::MakeShape(F32, {m, k}), "dot_lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(F32, {k, n}), "dot_rhs"))); + Tanh(Add( + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {m, k}), "dot_lhs"), + Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {k, n}), "dot_rhs"))); TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); @@ -255,30 +256,30 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) { XlaComputation condition; { XlaBuilder builder("condition"); - auto state = builder.Parameter(0, while_result_shape, "state"); - auto iteration = builder.GetTupleElement(state, 0); - builder.Gt(builder.ConstantR0(5), iteration); + auto state = Parameter(&builder, 0, while_result_shape, "state"); + auto iteration = GetTupleElement(state, 0); + Gt(ConstantR0(&builder, 5), iteration); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } XlaComputation body; { XlaBuilder builder("body"); - auto state = builder.Parameter(0, while_result_shape, "state"); - auto matrix = builder.GetTupleElement(state, 1); - auto next_iteration = builder.Add(builder.GetTupleElement(state, 0), - builder.ConstantR0(1)); - builder.Tuple({next_iteration, builder.Add(matrix, matrix)}); + auto state = Parameter(&builder, 0, while_result_shape, "state"); + auto matrix = GetTupleElement(state, 1); + auto next_iteration = + Add(GetTupleElement(state, 0), ConstantR0(&builder, 1)); + Tuple(&builder, {next_iteration, Add(matrix, matrix)}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } XlaBuilder builder(TestName()); auto initial_while_state = - builder.Tuple({builder.ConstantR0(0), - builder.Parameter(0, matrix_shape, "initial_value")}); - auto while_result = builder.While(condition, body, initial_while_state); - builder.Add(builder.GetTupleElement(while_result, 1), - builder.Parameter(1, matrix_shape, "other_value")); + Tuple(&builder, {ConstantR0(&builder, 0), + Parameter(&builder, 0, matrix_shape, "initial_value")}); + auto while_result = While(condition, body, initial_while_state); + Add(GetTupleElement(while_result, 1), + Parameter(&builder, 1, matrix_shape, "other_value")); TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc index a9f2915b458b1816926de727b3da21982d06f6c0..a075195618c42aaa11f7b1c17730e67889a2c308 100644 --- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc +++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc @@ -49,6 +49,7 @@ GTEST_API_ int main(int argc, char** argv) { } // Unfortunately Google's internal benchmark infrastructure has a // different API than Tensorflow's. + testing::InitGoogleTest(&argc, argv); #if defined(PLATFORM_GOOGLE) base::SetFlag(&FLAGS_benchmarks, pattern); RunSpecifiedBenchmarks(); diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index f7574e0b1cc95daee6d6743ba4e2e490ee87e7c6..3a7917cf3043de8a77f189f011bdeb3e8d2ddf3c 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -174,6 +174,11 @@ StatusOr ReplayComputation(const HloSnapshot& module, client->Compile(computation, argument_layouts, ExecutableBuildOptions()) .ValueOrDie(); + // Do not attmept to run the executable, if num_runs is less than 1. + if (opts.num_runs < 1) { + return Cancelled("Cancelled after compilation since --num_runs < 1."); + } + // Run the computation num_runs times, and return the result from the last // execution. StreamExecutorMemoryAllocator allocator( @@ -191,9 +196,6 @@ StatusOr ReplayComputation(const HloSnapshot& module, << static_cast(profile.compute_time_ns()) / 1e9 << "s"; } - // Check that --num_runs > 0, otherwise *result below will fail with an - // unhelpful error (because the loop didn't run any iterations). - CHECK_GT(opts.num_runs, 0) << "--num_runs must be > 0"; TF_ASSIGN_OR_RETURN(std::unique_ptr result_literal, client->ShapedBufferToLiteral(*result)); return std::move(*result_literal); diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index b4f45cc972d3d397ddff8e8d9163d1fef387392f..6041fae1595dacb309008857f1c758ee96a646bb 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -539,6 +540,11 @@ int64 FindIndex(const C& c, Value&& value) { return std::distance(c.begin(), it); } +template +bool ArrayContains(tensorflow::gtl::ArraySlice c, const T& value) { + return c_find(c, value) != c.end(); +} + template void InsertAt(C* c, int64 index, Value&& value) { c->insert(c->begin() + index, std::forward(value)); @@ -549,6 +555,12 @@ void EraseAt(C* c, int64 index) { c->erase(c->begin() + index); } +template +std::vector InlinedVectorToVector( + const tensorflow::gtl::InlinedVector& inlined_vector) { + return std::vector(inlined_vector.begin(), inlined_vector.end()); +} + // Returns true if `x` fits in 32-bits. template bool IsInt32(T x) { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 0af73e8a93060f4569ddef9697b89a6fa2b8674b..c7472173a705b7a6e1bee2f5221f23db0a77991d 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -274,6 +274,9 @@ message ExecutionProfile { // for the input data transfer since the memory is initialized with the proper // values before the execution. int64 compute_and_transfer_time_ns = 5; + + // The size of the binary code in the executable. + int64 executable_size_in_bytes = 6; } // Handle given to a user that represents an execution that the user launched diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 50b1ae5cc3cba2d6ac89c4415a3419ffdf7aec93..e2c85f39957f9b43ca238cd59c17bc86671ec20f 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -9,6 +9,7 @@ load("//third_party/mpi:mpi.bzl", "if_mpi") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") load("//tensorflow:tensorflow.bzl", "if_not_windows") +load("//tensorflow:tensorflow.bzl", "if_not_windows_cuda") py_library( name = "contrib_py", @@ -26,14 +27,12 @@ py_library( "//tensorflow/contrib/bayesflow:bayesflow_py", "//tensorflow/contrib/boosted_trees:init_py", "//tensorflow/contrib/checkpoint/python:checkpoint", - "//tensorflow/contrib/cloud:cloud_py", "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/coder:coder_py", "//tensorflow/contrib/compiler:compiler_py", "//tensorflow/contrib/autograph", "//tensorflow/contrib/constrained_optimization", - "//tensorflow/contrib/control_flow", "//tensorflow/contrib/copy_graph:copy_graph_py", "//tensorflow/contrib/crf:crf_py", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py", @@ -46,7 +45,6 @@ py_library( "//tensorflow/contrib/factorization:factorization_py", "//tensorflow/contrib/feature_column:feature_column_py", "//tensorflow/contrib/framework:framework_py", - "//tensorflow/contrib/fused_conv:fused_conv_py", "//tensorflow/contrib/gan", "//tensorflow/contrib/graph_editor:graph_editor_py", "//tensorflow/contrib/grid_rnn:grid_rnn_py", @@ -115,6 +113,7 @@ py_library( "//tensorflow/contrib/training:training_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:util", + "//tensorflow/python/estimator:estimator_py", ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_tensorrt([ "//tensorflow/contrib/tensorrt:init_py", ]) + select({ @@ -123,7 +122,11 @@ py_library( "//tensorflow/contrib/kafka", ], "//conditions:default": [], - }) + if_not_windows([ + }) + if_not_windows_cuda([ + "//tensorflow/contrib/fused_conv:fused_conv_py", # unresolved symbols, need to export more symbols + ]) + if_not_windows([ + "//tensorflow/contrib/bigtable", + "//tensorflow/contrib/cloud:cloud_py", "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", "//tensorflow/contrib/lite/python:lite", # unix dependency, need to fix code ]), diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index ad8c40395c2cdcc5e4288e04bb2115bd3627cdc9..9aad772f0acd941d50d6ba238d345616195a6939 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -30,7 +30,6 @@ from tensorflow.contrib import cluster_resolver from tensorflow.contrib import coder from tensorflow.contrib import compiler from tensorflow.contrib import constrained_optimization -from tensorflow.contrib import control_flow from tensorflow.contrib import copy_graph from tensorflow.contrib import crf from tensorflow.contrib import cudnn_rnn diff --git a/tensorflow/contrib/android/BUILD b/tensorflow/contrib/android/BUILD index c10179ba8b290b6209f5567d6323df4bcf711585..f0b1c92cf7e4b760381da38febd9682ce2a4f27c 100644 --- a/tensorflow/contrib/android/BUILD +++ b/tensorflow/contrib/android/BUILD @@ -1,6 +1,8 @@ # Description: # JNI-based Java inference interface for TensorFlow. +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/autograph/BUILD b/tensorflow/contrib/autograph/BUILD index 30dd846893c30b9205972bd5216cc1871ab03d76..ad700ac4a0342e2a7bc07a6ecf6710cea892e296 100644 --- a/tensorflow/contrib/autograph/BUILD +++ b/tensorflow/contrib/autograph/BUILD @@ -23,9 +23,9 @@ py_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/contrib/autograph/impl", + "//tensorflow/contrib/autograph/lang", "//tensorflow/contrib/autograph/pyct", "//tensorflow/contrib/autograph/utils", - "@gast_archive//:gast", - "@six_archive//:six", + "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/autograph/CONTRIBUTING.md b/tensorflow/contrib/autograph/CONTRIBUTING.md index a4aec8c74a9ad1418072471a5d3cde8c3b968a38..06fb7b03d5dbbfd2fcb6d6a2ecfe5c817f94a469 100644 --- a/tensorflow/contrib/autograph/CONTRIBUTING.md +++ b/tensorflow/contrib/autograph/CONTRIBUTING.md @@ -1,4 +1,4 @@ -# How to Contribute +# How to contribute We'd love to have your patches and contributions! Here are some guidelines. In general, we follow the [TensorFlow contributing guidelines](../../CONTRIBUTING.md), but have some [AutoGraph-specific style guidelines](STYLE_GUIDE.md). More details below. @@ -46,3 +46,50 @@ bazel test --config=opt --copt=-O3 --copt=-march=native \ ``` from the root of the `tensorflow` repository. For more details see the [main TensorFlow Contributing File](../../CONTRIBUTING.md) + +## Developer info + +### Module structure + +The graph below describes the dependencies between AutoGraph modules (not to be mistaken with the directory structure for these modules, which is flat): + +```dot +digraph d_modules { + autograph [style=filled]; + converters; + core; + impl; + lang; + operators; + + autograph -> impl + autograph -> lang + + impl -> converters + impl -> core + impl -> operators + + lang -> operators + + converters -> core + converters -> lang +} +``` + +`autograph` is the sole user-visible module. + +A short description of the modules: + + * `autograph`: the main module imported by the user and by the generated code; only contains declarations + * `impl`: high level code and the implementation of the api frontend + * `core`: base classes for the AutoGraph source code transformation logic; see in particular `converter.py` + * `lang`: special user-visible functions that serve as extensions to the Python language + * `converters`: collection of source code transformation modules specialized for particular AutoGraph features + * `operators`: collection of operators that AutoGraph overloads; these correspond to Python operators as well as Python syntactic structures, like control flow + +There are two additional modules, `pyct` and `utils`. These are independent of AutoGraph: + + * `pyct`: a general purpose Python source code transformation library + * `utils`: the kitchen sync; deprecated + +Note: we have a long term plan to factor out an implementation of `impl` and `converters` that is independent of autograph, into a general purpose Python operator overloading library. diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md index 829a57d8e61ee4a41076f7397488cd85bdca1376..7e26f4711851138c1834f881621ebfa227a85821 100644 --- a/tensorflow/contrib/autograph/README.md +++ b/tensorflow/contrib/autograph/README.md @@ -4,7 +4,7 @@ IMPORTANT: AutoGraph is alpha software, and under active development. Expect rou AutoGraph is a Python to TensorFlow compiler. -With AutoGraph, you can write [Eager style](https://www.tensorflow.org/programmers_guide/eager) code in a concise manner, and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops. +With AutoGraph, you can write [Eager style](https://www.tensorflow.org/guide/eager) code in a concise manner, and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops. For example, this Python function: diff --git a/tensorflow/contrib/autograph/__init__.py b/tensorflow/contrib/autograph/__init__.py index dbdbad8f4c91c725294baa36acebbaf5b5e8cf5c..361cf2d77c7e46912d5bff5881df2ffa897c5179 100644 --- a/tensorflow/contrib/autograph/__init__.py +++ b/tensorflow/contrib/autograph/__init__.py @@ -30,9 +30,9 @@ from tensorflow.contrib.autograph.impl.api import do_not_convert from tensorflow.contrib.autograph.impl.api import RunMode from tensorflow.contrib.autograph.impl.api import to_code from tensorflow.contrib.autograph.impl.api import to_graph -from tensorflow.contrib.autograph.impl.directives import set_element_type -from tensorflow.contrib.autograph.impl.directives import set_loop_options -from tensorflow.contrib.autograph.impl.special_functions import stack +from tensorflow.contrib.autograph.lang.directives import set_element_type +from tensorflow.contrib.autograph.lang.directives import set_loop_options +from tensorflow.contrib.autograph.lang.special_functions import stack from tensorflow.contrib.autograph.pyct.transformer import AutographParseError from tensorflow.python.util.all_util import remove_undocumented @@ -46,7 +46,7 @@ _allowed_symbols = [ 'to_graph', # Overloaded operators 'operators', - # Special functions and directives + # Python language "extensions" 'set_element_type', 'set_loop_options', 'stack', diff --git a/tensorflow/contrib/autograph/converters/BUILD b/tensorflow/contrib/autograph/converters/BUILD index 284ad84be566199adaaa1ab641d37528ae4dfd2d..b2e2e27673dafe290cef40a9fe0a834bfe1ea61f 100644 --- a/tensorflow/contrib/autograph/converters/BUILD +++ b/tensorflow/contrib/autograph/converters/BUILD @@ -36,25 +36,12 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ - "@gast_archive//:gast", - ], -) - -py_library( - name = "test_lib", - srcs = [ - "converter_test_base.py", - ], - srcs_version = "PY2AND3", - visibility = ["//tensorflow:__subpackages__"], - deps = [ - ":converters", - "//tensorflow/contrib/autograph/operators", + "//tensorflow/contrib/autograph/core", + "//tensorflow/contrib/autograph/lang", "//tensorflow/contrib/autograph/pyct", "//tensorflow/contrib/autograph/pyct/static_analysis", - "//tensorflow/contrib/autograph/utils", + "//tensorflow/python:util", "@gast_archive//:gast", - "@six_archive//:six", ], ) @@ -64,7 +51,8 @@ py_test( srcs_version = "PY2AND3", tags = ["no_windows"], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -74,7 +62,8 @@ py_test( srcs = ["break_statements_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -85,7 +74,8 @@ py_test( srcs_version = "PY2AND3", tags = ["no_windows"], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -97,7 +87,8 @@ py_test( srcs_version = "PY2AND3", tags = ["no_windows"], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/contrib/autograph/impl", "//tensorflow/python:client_testlib", ], @@ -108,7 +99,8 @@ py_test( srcs = ["continue_statements_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -118,7 +110,8 @@ py_test( srcs = ["control_flow_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -127,8 +120,13 @@ py_test( name = "decorators_test", srcs = ["decorators_test.py"], srcs_version = "PY2AND3", + tags = [ + "no_pip", + "no_windows", + ], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -137,7 +135,8 @@ py_test( name = "name_scopes_test", srcs = ["name_scopes_test.py"], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], @@ -148,7 +147,8 @@ py_test( srcs = ["list_comprehension_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -158,7 +158,8 @@ py_test( srcs = ["lists_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -168,7 +169,8 @@ py_test( srcs = ["logical_expressions_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -183,7 +185,8 @@ py_test( "notap", ], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -193,7 +196,8 @@ py_test( srcs = ["single_return_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], @@ -204,7 +208,8 @@ py_test( srcs = ["ifexp_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], @@ -215,7 +220,8 @@ py_test( srcs = ["slices_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], diff --git a/tensorflow/contrib/autograph/converters/asserts.py b/tensorflow/contrib/autograph/converters/asserts.py index 3b0db677ce5e417e7afea8d8fe4121a0352bb6d7..e664a403a5fb800e7d0dddfa5695330927aaf4e0 100644 --- a/tensorflow/contrib/autograph/converters/asserts.py +++ b/tensorflow/contrib/autograph/converters/asserts.py @@ -20,11 +20,11 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -class AssertsTransformer(transformer.Base): +class AssertsTransformer(converter.Base): """Transforms Print nodes to Call so they can be handled as functions.""" def visit_Assert(self, node): @@ -45,5 +45,5 @@ class AssertsTransformer(transformer.Base): raise NotImplementedError('can only convert string messages for now.') -def transform(node, context): - return AssertsTransformer(context).visit(node) +def transform(node, ctx): + return AssertsTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/asserts_test.py b/tensorflow/contrib/autograph/converters/asserts_test.py index cc913febe8d0f411588af69b87ec52ce58f4469c..2cd0e626bc4552bd40bc94b890fdcc7efcafb3f3 100644 --- a/tensorflow/contrib/autograph/converters/asserts_test.py +++ b/tensorflow/contrib/autograph/converters/asserts_test.py @@ -21,11 +21,11 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.converters import asserts -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.platform import test -class AssertsTest(converter_test_base.TestCase): +class AssertsTest(converter_testing.TestCase): def test_transform(self): diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py index 775d92c1d9f8bc35d1eda62f3f3ef7ee43414779..a990e359a2a25a57ee2a4f8a866350633f3b9ea8 100644 --- a/tensorflow/contrib/autograph/converters/break_statements.py +++ b/tensorflow/contrib/autograph/converters/break_statements.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -29,7 +29,7 @@ BREAK_USED = 'break_used' CONTROL_VAR_NAME = 'control_var_name' -class BreakStatementTransformer(transformer.Base): +class BreakStatementTransformer(converter.Base): """Canonicalizes break statements into additional conditionals.""" def visit_Break(self, node): @@ -67,7 +67,7 @@ class BreakStatementTransformer(transformer.Base): def visit_While(self, node): scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - break_var = self.context.namer.new_symbol('break_', scope.referenced) + break_var = self.ctx.namer.new_symbol('break_', scope.referenced) node.test = self.visit(node.test) node.body, break_used = self._track_body(node.body, break_var) @@ -97,7 +97,7 @@ class BreakStatementTransformer(transformer.Base): def visit_For(self, node): scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - break_var = self.context.namer.new_symbol('break_', scope.referenced) + break_var = self.ctx.namer.new_symbol('break_', scope.referenced) node.target = self.visit(node.target) node.iter = self.visit(node.iter) @@ -137,5 +137,5 @@ class BreakStatementTransformer(transformer.Base): return node -def transform(node, context): - return BreakStatementTransformer(context).visit(node) +def transform(node, ctx): + return BreakStatementTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/break_statements_test.py b/tensorflow/contrib/autograph/converters/break_statements_test.py index 1af59e9b5260fe0d3a3ef72c7a003dc451e230f3..dcff1c54c2f9300d58d217517e108d634ae85fb4 100644 --- a/tensorflow/contrib/autograph/converters/break_statements_test.py +++ b/tensorflow/contrib/autograph/converters/break_statements_test.py @@ -19,11 +19,11 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph.converters import break_statements -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.platform import test -class BreakCanonicalizationTest(converter_test_base.TestCase): +class BreakCanonicalizationTest(converter_testing.TestCase): def test_basic_while(self): diff --git a/tensorflow/contrib/autograph/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py index 231e4ee35a72f51845a476d9f605986ac73b4676..b26c52294c2d1c11ce14d8a2903f7f88079a703f 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions.py @@ -20,11 +20,11 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -class BuiltinFunctionTransformer(transformer.Base): +class BuiltinFunctionTransformer(converter.Base): """Handles builtin functions. This transformer only covers functions that are translated into a @@ -68,5 +68,5 @@ class BuiltinFunctionTransformer(transformer.Base): return self.visit(function_call) -def transform(node, context): - return BuiltinFunctionTransformer(context).visit(node) +def transform(node, ctx): + return BuiltinFunctionTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py index 30272409df322560b04ba75b3e1cb6f9ad5ff0af..e9000e518ce14f9e0ea486d5b3e374439b8c78ca 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions_test.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py @@ -23,13 +23,13 @@ import sys import six from tensorflow.contrib.autograph.converters import builtin_functions -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class BuiltinFunctionsTest(converter_test_base.TestCase): +class BuiltinFunctionsTest(converter_testing.TestCase): def test_len(self): diff --git a/tensorflow/contrib/autograph/converters/call_trees.py b/tensorflow/contrib/autograph/converters/call_trees.py index b6ecdcb7809b1ad7e7461324cb6a110ef4180609..a36b3d77a9233daed864c616306b2ad27f582a38 100644 --- a/tensorflow/contrib/autograph/converters/call_trees.py +++ b/tensorflow/contrib/autograph/converters/call_trees.py @@ -26,12 +26,12 @@ from collections import namedtuple import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import inspect_utils from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect @@ -45,6 +45,9 @@ KNOWN_NUMPY_FUNCTIONS = { } +# TODO(mdan): Get rid of these interfaces. Can now depend directly on Namer. + + class FunctionNamer(object): """Describes the interface for CallTreeTransformer's namer.""" @@ -76,20 +79,18 @@ class FunctionNamer(object): raise NotImplementedError() -class CallTreeTransformer(transformer.Base): - """Transforms the call tree by renaming transformed symbols.""" +# TODO(mdan): Rename to CallsTransformer. - def __init__(self, context, uncompiled_modules, nocompile_decorators): - super(CallTreeTransformer, self).__init__(context) - self.uncompiled_modules = uncompiled_modules - self.nocompile_decorators = nocompile_decorators + +class CallTreeTransformer(converter.Base): + """Transforms the call tree by renaming transformed symbols.""" def _resolve_name(self, node): """Used to resolve decorator info.""" if isinstance(node, gast.Call): return self._resolve_name(node.func) if isinstance(node, gast.Name): - return self.context.namespace.get(node.id) + return self.ctx.namespace.get(node.id) if isinstance(node, gast.Attribute): parent = self._resolve_name(node.value) if parent is not None: @@ -119,12 +120,12 @@ class CallTreeTransformer(transformer.Base): """Determines whether an entity should be compiled in the context.""" # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether. module_name = fqn[0] - for mod in self.uncompiled_modules: + for mod in self.ctx.program.uncompiled_modules: if module_name.startswith(mod[0] + '.'): return False for i in range(1, len(fqn)): - if fqn[:i] in self.uncompiled_modules: + if fqn[:i] in self.ctx.program.uncompiled_modules: return False # Check for local decorations @@ -140,7 +141,7 @@ class CallTreeTransformer(transformer.Base): if hasattr(target_entity, '__pyct_is_compile_decorator'): return False - if target_entity in self.nocompile_decorators: + if target_entity in self.ctx.program.autograph_decorators: return False # Inspect the target function decorators. If any include a @convert @@ -159,7 +160,7 @@ class CallTreeTransformer(transformer.Base): for dec in target_node.decorator_list: decorator_fn = self._resolve_name(dec) if (decorator_fn is not None and - decorator_fn in self.nocompile_decorators): + decorator_fn in self.ctx.program.autograph_decorators): return False return True @@ -174,7 +175,7 @@ class CallTreeTransformer(transformer.Base): return node if anno.hasanno(node, 'is_constructor'): - new_name = self.context.namer.compiled_class_name( + new_name = self.ctx.namer.compiled_class_name( target_fqn, live_entity=target_entity) do_rename = True else: @@ -183,7 +184,7 @@ class CallTreeTransformer(transformer.Base): else: # Fallback - not reliable. owner_type = inspect_utils.getmethodclass(target_entity) - new_name, do_rename = self.context.namer.compiled_function_name( + new_name, do_rename = self.ctx.namer.compiled_function_name( target_fqn, live_entity=target_entity, owner_type=owner_type) if do_rename: @@ -264,15 +265,16 @@ class CallTreeTransformer(transformer.Base): return node def visit_Call(self, node): - # If the function is wrapped by one of the marker decorators, + # If the function call is wrapped by one of the marker decorators, # consider it graph ready. if anno.hasanno(node.func, 'live_val'): target_entity = anno.getanno(node.func, 'live_val') - if target_entity in self.nocompile_decorators: + if target_entity in self.ctx.program.autograph_decorators: if len(node.args) < 1: raise ValueError( 'Found call to decorator function "%s", but it had no arguments. ' - 'A decorator needs at least an argument.') + 'A decorator needs at least one positional argument.' % + target_entity) anno.setanno(node.args[0], 'graph_ready', True) self.generic_visit(node) @@ -309,27 +311,20 @@ class CallTreeTransformer(transformer.Base): # ensure that they return the correct value. return node - if self.context.recursive: + if self.ctx.program.recursive: node = self._insert_dynamic_conversion(node) return node -def transform(node, context, uncompiled_modules, nocompile_decorators): +def transform(node, ctx): """Transform function call to the compiled counterparts. Args: - node: AST to transform. - context: An EntityContext object. - uncompiled_modules: set of string tuples, each tuple represents the fully - qualified name of a package containing functions that will not be - compiled. - nocompile_decorators: A tuple containing decorators to be stripped from - functions during conversion. + node: AST + ctx: EntityContext Returns: A tuple (node, new_names): node: The transformed AST new_names: set(string), containing any newly-generated names """ - t = CallTreeTransformer(context, uncompiled_modules, nocompile_decorators) - node = t.visit(node) - return node + return CallTreeTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/call_trees_test.py b/tensorflow/contrib/autograph/converters/call_trees_test.py index 303dd54a4ee49de27fad0c5cdc2d6274abfe0fa8..27d8281b856f505062ceacc8ad50c8cbc2ce6c81 100644 --- a/tensorflow/contrib/autograph/converters/call_trees_test.py +++ b/tensorflow/contrib/autograph/converters/call_trees_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.autograph.converters import call_trees -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -29,7 +29,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class CallTreesTest(converter_test_base.TestCase): +class CallTreesTest(converter_testing.TestCase): def test_basic(self): @@ -43,7 +43,7 @@ class CallTreesTest(converter_test_base.TestCase): return test_fn_1(a) + 1 node = self.parse_and_analyze(test_fn_2, {'test_fn_1': test_fn_1}) - node = call_trees.transform(node, self.ctx, (), ()) + node = call_trees.transform(node, self.ctx) with self.compiled(node) as result: # Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1 @@ -60,7 +60,7 @@ class CallTreesTest(converter_test_base.TestCase): return f() + 3 node = self.parse_and_analyze(test_fn_2, {}) - node = call_trees.transform(node, self.ctx, (), ()) + node = call_trees.transform(node, self.ctx) with self.compiled(node) as result: # 10 = 7 (from the mock) + 3 (from test_fn_2) @@ -78,9 +78,9 @@ class CallTreesTest(converter_test_base.TestCase): node = self.parse_and_analyze( TestClass.test_fn_2, {'TestClass': TestClass}, - namer=converter_test_base.FakeNoRenameNamer(), + namer=converter_testing.FakeNoRenameNamer(), arg_types={'self': (TestClass.__name__, TestClass)}) - node = call_trees.transform(node, self.ctx, (), ()) + node = call_trees.transform(node, self.ctx) with self.compiled(node) as result: tc = TestClass() @@ -92,7 +92,7 @@ class CallTreesTest(converter_test_base.TestCase): setattr(a, 'foo', 'bar') node = self.parse_and_analyze(test_fn, {'setattr': setattr}) - node = call_trees.transform(node, self.ctx, (), ()) + node = call_trees.transform(node, self.ctx) with self.compiled(node) as result: with self.test_session() as sess: @@ -115,7 +115,7 @@ class CallTreesTest(converter_test_base.TestCase): return np.random.binomial(2, 0.5) node = self.parse_and_analyze(test_fn, {'np': np}) - node = call_trees.transform(node, self.ctx, (), ()) + node = call_trees.transform(node, self.ctx) with self.compiled(node, dtypes.int64) as result: result.np = np @@ -130,13 +130,13 @@ class CallTreesTest(converter_test_base.TestCase): a = math_ops.add(a, constant_op.constant(1)) return a - node = self.parse_and_analyze(test_fn, { - 'math_ops': math_ops, - 'constant_op': constant_op - }) - node = call_trees.transform(node, self.ctx, - set(((math_ops.__name__,), - (constant_op.__name__,))), ()) + node = self.parse_and_analyze( + test_fn, { + 'math_ops': math_ops, + 'constant_op': constant_op + }, + arg_types=set(((math_ops.__name__,), (constant_op.__name__,)))) + node = call_trees.transform(node, self.ctx) with self.compiled(node) as result: result.math_ops = math_ops diff --git a/tensorflow/contrib/autograph/converters/continue_statements.py b/tensorflow/contrib/autograph/converters/continue_statements.py index 0417817a77e706fc0ce805f7391bea600f5fbb2d..958bde0a58764e705c35ab73ce879b2c11ce7cdc 100644 --- a/tensorflow/contrib/autograph/converters/continue_statements.py +++ b/tensorflow/contrib/autograph/converters/continue_statements.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -31,7 +31,7 @@ GUARD_CREATED = 'guard_created' CREATE_GUARD_NEXT = 'create_guard_next' -class ContinueCanonicalizationTransformer(transformer.Base): +class ContinueCanonicalizationTransformer(converter.Base): """Canonicalizes continue statements into additional conditionals.""" def visit_Continue(self, node): @@ -85,7 +85,7 @@ class ContinueCanonicalizationTransformer(transformer.Base): def _visit_loop_body(self, node, nodes): self.enter_local_scope() scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - continue_var = self.context.namer.new_symbol('continue_', scope.referenced) + continue_var = self.ctx.namer.new_symbol('continue_', scope.referenced) self.set_local(CONTROL_VAR_NAME, continue_var) nodes = self.visit_block(nodes, after_visit=self._postprocess_statement) @@ -135,5 +135,5 @@ class ContinueCanonicalizationTransformer(transformer.Base): return node -def transform(node, namer): - return ContinueCanonicalizationTransformer(namer).visit(node) +def transform(node, ctx): + return ContinueCanonicalizationTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/continue_statements_test.py b/tensorflow/contrib/autograph/converters/continue_statements_test.py index bcbb316d7459aa5a25bb0bd128cd6e359a393288..2ce1837972c50bbc4921487a290f5cb2f782b5f3 100644 --- a/tensorflow/contrib/autograph/converters/continue_statements_test.py +++ b/tensorflow/contrib/autograph/converters/continue_statements_test.py @@ -19,11 +19,11 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph.converters import continue_statements -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.platform import test -class ContinueCanonicalizationTest(converter_test_base.TestCase): +class ContinueCanonicalizationTest(converter_testing.TestCase): def test_basic_continue(self): diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py index 1e718f02d10ea1a520066c74f520144feee242b9..f4a87106279d5658ecaa90a577cbe741711ba22e 100644 --- a/tensorflow/contrib/autograph/converters/control_flow.py +++ b/tensorflow/contrib/autograph/converters/control_flow.py @@ -20,11 +20,11 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis import cfg from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -45,9 +45,8 @@ class SymbolNamer(object): raise NotImplementedError() -class ControlFlowTransformer(transformer.Base): - """Transforms control flow structures like loops and conditionals.""" - +class ControlFlowTransformer(converter.Base): + """Transforms control flow structures like loops an conditionals.""" def _create_cond_branch(self, body_name, aliased_orig_names, aliased_new_names, body, returns): if aliased_orig_names: @@ -141,10 +140,10 @@ class ControlFlowTransformer(transformer.Base): aliased_orelse_orig_names = tuple(orelse_scope.modified - orelse_scope.created) aliased_body_new_names = tuple( - self.context.namer.new_symbol(s.ssf(), body_scope.referenced) + self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced) for s in aliased_body_orig_names) aliased_orelse_new_names = tuple( - self.context.namer.new_symbol(s.ssf(), orelse_scope.referenced) + self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced) for s in aliased_orelse_orig_names) alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names)) @@ -165,9 +164,8 @@ class ControlFlowTransformer(transformer.Base): else: results = gast.Tuple([s.ast() for s in modified], None) - body_name = self.context.namer.new_symbol('if_true', body_scope.referenced) - orelse_name = self.context.namer.new_symbol('if_false', - orelse_scope.referenced) + body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced) + orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced) if modified: def build_returns(aliased_names, alias_map, scope): @@ -235,7 +233,7 @@ class ControlFlowTransformer(transformer.Base): raise ValueError('cannot convert while loop: no outputs') state_ssf = [ - self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state + self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state ] ssf_map = { name: ssf @@ -267,11 +265,9 @@ class ControlFlowTransformer(transformer.Base): state=state, state_ssf=state_ssf, state_ast_tuple=state_ast_tuple, - test_name=self.context.namer.new_symbol('loop_test', - body_scope.referenced), + test_name=self.ctx.namer.new_symbol('loop_test', body_scope.referenced), test=test, - body_name=self.context.namer.new_symbol('loop_body', - body_scope.referenced), + body_name=self.ctx.namer.new_symbol('loop_body', body_scope.referenced), body=node_body, extra_deps=tuple(s.ast() for s in cond_closure), ) @@ -288,7 +284,7 @@ class ControlFlowTransformer(transformer.Base): state = list(body_closure) state_ssf = [ - self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state + self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state ] ssf_map = { name: ssf @@ -326,17 +322,16 @@ class ControlFlowTransformer(transformer.Base): state_ast_tuple=state_ast_tuple, iter_=node.iter, iterate=node.target, - extra_test_name=self.context.namer.new_symbol('extra_test', - all_referenced), + extra_test_name=self.ctx.namer.new_symbol('extra_test', all_referenced), extra_test_expr=extra_test, - body_name=self.context.namer.new_symbol('loop_body', all_referenced), + body_name=self.ctx.namer.new_symbol('loop_body', all_referenced), body=node_body) return node -def transform(node, context): - cfg.run_analyses(node, cfg.Liveness(context)) - cfg.run_analyses(node, cfg.Defined(context)) - node = ControlFlowTransformer(context).visit(node) +def transform(node, ctx): + cfg.run_analyses(node, cfg.Liveness(ctx.info)) + cfg.run_analyses(node, cfg.Defined(ctx.info)) + node = ControlFlowTransformer(ctx).visit(node) return node diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py index 9d23d9b5b7e8e8480e04fccc1c8c81799abf382b..735eb92a0dd06ee7fd621b92b1a8f894e09cee4a 100644 --- a/tensorflow/contrib/autograph/converters/control_flow_test.py +++ b/tensorflow/contrib/autograph/converters/control_flow_test.py @@ -19,7 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph.converters import control_flow -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -27,7 +27,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import test -class ControlFlowTest(converter_test_base.TestCase): +class ControlFlowTest(converter_testing.TestCase): def test_simple_while(self): diff --git a/tensorflow/contrib/autograph/converters/decorators.py b/tensorflow/contrib/autograph/converters/decorators.py index 92445f31746cf94856ea43893f99a2ba60355fb5..3471bd11d6073f57a2703b438df95a60f19e8e0c 100644 --- a/tensorflow/contrib/autograph/converters/decorators.py +++ b/tensorflow/contrib/autograph/converters/decorators.py @@ -24,19 +24,14 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import pretty_printer +from tensorflow.python.util import tf_inspect -class DecoratorsTransformer(gast.NodeTransformer): +class DecoratorsTransformer(converter.Base): """Converts or removes decorators.""" - def __init__(self, remove_decorators): - self.remove_decorators = remove_decorators - self.additional_dependencies = set() - - # pylint:disable=invalid-name - def visit_FunctionDef(self, node): self.generic_visit(node) kept_decorators = [] @@ -58,31 +53,53 @@ class DecoratorsTransformer(gast.NodeTransformer): # This is currently verified by tests. continue - if not anno.hasanno(dec_func, 'live_val'): - raise ValueError( - 'Could not resolve decorator: %s' % pretty_printer.fmt(dec_func)) - + original_dec = anno.getanno(dec_func, anno.Basic.QN) dec_value = anno.getanno(dec_func, 'live_val') - if dec_value not in self.remove_decorators: - kept_decorators.append((dec, dec_value)) - for _, dec_value in kept_decorators: - if dec_value.__module__ == '__main__': + if dec_value in self.ctx.program.autograph_decorators: + # AutoGraph decorators do not need to be preserved. + continue + + # When using foo.bar.baz, we only really need to grab foo and import + # that. + dec_support_node = dec_func + while isinstance(dec_support_node, gast.Attribute): + dec_support_node = dec_support_node.value + + if not anno.hasanno(dec_support_node, 'live_val'): raise ValueError( - 'decorator "%s" was not allowed because it is declared ' - 'in the module "%s". To fix this, declare it in a separate ' - 'module that we can import it from.' % (dec_value, - dec_value.__module__)) + 'could not resolve symbol "%s" when looking up decorator "%s"' % + (anno.getanno(dec_support_node, anno.Basic.QN), original_dec)) + + dec_support = anno.getanno(dec_support_node, 'live_val') + # The tuple contains: + # * the AST that represents the decorator + # * the entity supporting the decorator (i.e., what we need to import) + # * the name of the module that needs to be imported for this decorator + # to properly resolve. + # Examples: + # for foo.bar, the tuple is (, , 'foo') + # for baz, the tuple is (, , 'baz') + kept_decorators.append((dec, dec_support, + anno.getanno(dec_support_node, anno.Basic.QN))) + + for _, dec_support, name in kept_decorators: + if tf_inspect.ismodule(dec_support): + self.ctx.program.additional_imports.add( + 'import %s as %s' % (dec_support.__name__, name)) else: - self.additional_dependencies.add(dec_value) - - node.decorator_list = [dec for dec, _ in kept_decorators] + if dec_support.__module__ == '__main__': + raise ValueError( + 'decorator "%s" was not allowed because it is declared ' + 'in the module "%s". To fix this, declare it in a separate ' + 'module that we can import it from.' % (dec_support, + dec_support.__module__)) + self.ctx.program.additional_imports.add( + 'from %s import %s' % (dec_support.__module__, name)) + + node.decorator_list = [dec for dec, _, _ in kept_decorators] return node - # pylint:enable=invalid-name - -def transform(node, remove_decorators): - transformer = DecoratorsTransformer(remove_decorators) - node = transformer.visit(node) - return node, transformer.additional_dependencies +def transform(node, ctx): + return DecoratorsTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/decorators_test.py b/tensorflow/contrib/autograph/converters/decorators_test.py index 9c01f689127dbedad7669c65b03e7da071b2d64d..d41c7fde2474803a438100e7e00ce8e9f675de45 100644 --- a/tensorflow/contrib/autograph/converters/decorators_test.py +++ b/tensorflow/contrib/autograph/converters/decorators_test.py @@ -20,9 +20,10 @@ from __future__ import print_function from functools import wraps -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import decorators +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.platform import test @@ -39,28 +40,35 @@ def simple_decorator(f): return lambda a: f(a) + 1 -def self_removing_decorator(removing_wrapper): +def self_transform_decorator(transform): + def decorator(f): @wraps(f) def wrapper(*args): # This removing wrapper is defined in the test below. This setup is so - # intricate just to simulate how we use the transformer in practice. - transformed_f = removing_wrapper(f, (self_removing_decorator,)) + # intricate in order to simulate how we use the transformer in practice. + transformed_f = transform(f, (self_transform_decorator,)) return transformed_f(*args) + 1 return wrapper return decorator -class DecoratorsTest(converter_test_base.TestCase): +class DecoratorsTest(converter_testing.TestCase): - def _remover_wrapper(self, f, remove_decorators): + def _transform(self, f, autograph_decorators): namespace = { - 'self_removing_decorator': self_removing_decorator, - 'simple_decorator': simple_decorator + 'self_transform_decorator': self_transform_decorator, + 'simple_decorator': simple_decorator, + 'converter_testing': converter_testing, } - node = self.parse_and_analyze(f, namespace) - node, _ = decorators.transform(node, remove_decorators=remove_decorators) - result, _ = compiler.ast_to_object(node) + node = self.parse_and_analyze( + f, + namespace, + recursive=False, + autograph_decorators=autograph_decorators) + node = decorators.transform(node, self.ctx) + import_line = '\n'.join(self.ctx.program.additional_imports) + result, _ = compiler.ast_to_object(node, source_prefix=import_line) return getattr(result, f.__name__) def test_noop(self): @@ -69,15 +77,14 @@ class DecoratorsTest(converter_test_base.TestCase): return a node = self.parse_and_analyze(test_fn, {}) - node, deps = decorators.transform(node, remove_decorators=()) + node = decorators.transform(node, self.ctx) result, _ = compiler.ast_to_object(node) - self.assertFalse(deps) self.assertEqual(1, result.test_fn(1)) def test_function(self): - @self_removing_decorator(self._remover_wrapper) + @self_transform_decorator(self._transform) def test_fn(a): return a @@ -88,7 +95,7 @@ class DecoratorsTest(converter_test_base.TestCase): class TestClass(object): - @self_removing_decorator(self._remover_wrapper) + @self_transform_decorator(self._transform) def test_fn(self, a): return a @@ -101,38 +108,39 @@ class DecoratorsTest(converter_test_base.TestCase): # Note that reversing the order of this two doesn't work. @classmethod - @self_removing_decorator(self._remover_wrapper) + @self_transform_decorator(self._transform) def test_fn(cls, a): return a # 2 = 1 (a) + 1 (decorator applied exactly once) self.assertEqual(2, TestClass.test_fn(1)) - def test_nested_decorators(self): + def test_nested_decorators_local(self): - @self_removing_decorator(self._remover_wrapper) + @self_transform_decorator(self._transform) def test_fn(a): @simple_decorator def inner_fn(b): return b + 11 return inner_fn(a) - with self.assertRaises(ValueError): + # Expected to fail because simple_decorator cannot be imported. + with self.assertRaises(transformer.AutographParseError): test_fn(1) - # TODO(mdan): Uncomment this test once converter_test_base is updated. - # (can't do it now because it has unrelated pending changes) - # def test_nested_decorators(self): - # - # @self_removing_decorator(self._remover_wrapper) - # def test_fn(a): - # @imported_decorator - # def inner_fn(b): - # return b + 11 - # return inner_fn(a) - # - # # 14 = 1 (a) + 1 (simple_decorator) + 11 (inner_fn) - # self.assertEqual(14, test_fn(1)) + def test_nested_decorators_imported(self): + + @self_transform_decorator(self._transform) + def test_fn(a): + + @converter_testing.imported_decorator + def inner_fn(b): + return b + 11 + + return inner_fn(a) + + # 14 = 1 (a) + 1 (simple_decorator) + 11 (inner_fn) + self.assertEqual(14, test_fn(1)) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/ifexp.py b/tensorflow/contrib/autograph/converters/ifexp.py index 616d222762e09feeba1809f119d915dfbe522283..e996138498ab2b7efa76671d8cc67fd4c6a9d9b8 100644 --- a/tensorflow/contrib/autograph/converters/ifexp.py +++ b/tensorflow/contrib/autograph/converters/ifexp.py @@ -18,11 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -class IfExp(transformer.Base): +class IfExp(converter.Base): """Canonicalizes all IfExp nodes into plain conditionals.""" def visit_IfExp(self, node): @@ -34,16 +34,16 @@ class IfExp(transformer.Base): return desugared_ifexp -def transform(node, context): +def transform(node, ctx): """Desugar IfExp nodes into plain conditionals. Args: - node: an AST node to transform - context: a context object + node: ast.AST, the node to transform + ctx: converter.EntityContext Returns: new_node: an AST with no IfExp nodes, only conditionals. """ - node = IfExp(context).visit(node) + node = IfExp(ctx).visit(node) return node diff --git a/tensorflow/contrib/autograph/converters/ifexp_test.py b/tensorflow/contrib/autograph/converters/ifexp_test.py index ac6849dcb4bd7dacd84bb205f5c65395d8c2f51e..cdd5a2f591edc1138df1c165577ed375131ddf09 100644 --- a/tensorflow/contrib/autograph/converters/ifexp_test.py +++ b/tensorflow/contrib/autograph/converters/ifexp_test.py @@ -19,12 +19,12 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph import utils -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import ifexp +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.platform import test -class IfExpTest(converter_test_base.TestCase): +class IfExpTest(converter_testing.TestCase): def compiled_fn(self, test_fn, *args): node = self.parse_and_analyze(test_fn, {}) diff --git a/tensorflow/contrib/autograph/converters/list_comprehension.py b/tensorflow/contrib/autograph/converters/list_comprehension.py index d7f292015164e047d054c5d1fb0b391e960bb73d..c4a13ee822ab84706df83256d9e9684c3f7dacba 100644 --- a/tensorflow/contrib/autograph/converters/list_comprehension.py +++ b/tensorflow/contrib/autograph/converters/list_comprehension.py @@ -31,17 +31,14 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -class ListCompCanonicalizationTransformer(transformer.Base): +class ListCompCanonicalizationTransformer(converter.Base): """NodeTransformer to canonicalize list comprehensions.""" - def __init__(self, context): - super(ListCompCanonicalizationTransformer, self).__init__(context) - def make_update_list_node(self, list_, elt): return templates.replace('list_.append(elt)', list_=list_, elt=elt)[0] @@ -76,5 +73,5 @@ class ListCompCanonicalizationTransformer(transformer.Base): return make_list + loop_body -def transform(node, context): - return ListCompCanonicalizationTransformer(context).visit(node) +def transform(node, ctx): + return ListCompCanonicalizationTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/list_comprehension_test.py b/tensorflow/contrib/autograph/converters/list_comprehension_test.py index 4758671f5ec83c26cfa54be0ef68f5f564094f6c..2bbee93412ce3174a14f3d60af9435dcf3b82cc6 100644 --- a/tensorflow/contrib/autograph/converters/list_comprehension_test.py +++ b/tensorflow/contrib/autograph/converters/list_comprehension_test.py @@ -18,12 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import list_comprehension +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.platform import test -class ListCompTest(converter_test_base.TestCase): +class ListCompTest(converter_testing.TestCase): def test_basic(self): diff --git a/tensorflow/contrib/autograph/converters/lists.py b/tensorflow/contrib/autograph/converters/lists.py index c15dfff9e8ebd8b96fd4aff82459a6fd7d0ac8ab..d77a04479826779b8aa859d70f2f7ff51138f841 100644 --- a/tensorflow/contrib/autograph/converters/lists.py +++ b/tensorflow/contrib/autograph/converters/lists.py @@ -32,10 +32,10 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -43,7 +43,7 @@ from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno POP_USES = 'pop_uses' -class ListTransformer(transformer.Base): +class ListTransformer(converter.Base): """Converts lists and related operations to their TF counterpart.""" def visit_List(self, node): @@ -94,7 +94,7 @@ class ListTransformer(transformer.Base): target_name = anno.getanno(target_node, anno.Basic.QN).ssf() else: target_name = 'list' - pop_var_name = self.context.namer.new_symbol(target_name, scope.referenced) + pop_var_name = self.ctx.namer.new_symbol(target_name, scope.referenced) pop_uses = self.get_local(POP_USES, []) pop_uses.append((node, pop_var_name)) @@ -223,5 +223,5 @@ class ListTransformer(transformer.Base): return node -def transform(node, context): - return ListTransformer(context).visit(node) +def transform(node, ctx): + return ListTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/lists_test.py b/tensorflow/contrib/autograph/converters/lists_test.py index 9f18ab9f44dd8c3f341a02b950f75317c676eff8..ea04097b28deedd705164bd95ab62dba3e3c7834 100644 --- a/tensorflow/contrib/autograph/converters/lists_test.py +++ b/tensorflow/contrib/autograph/converters/lists_test.py @@ -19,8 +19,8 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph import utils -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import lists +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -28,7 +28,7 @@ from tensorflow.python.ops import list_ops from tensorflow.python.platform import test -class ListTest(converter_test_base.TestCase): +class ListTest(converter_testing.TestCase): def test_empty_list(self): diff --git a/tensorflow/contrib/autograph/converters/logical_expressions.py b/tensorflow/contrib/autograph/converters/logical_expressions.py index 3a795a315a3c2aa08ac1577a204102755b6e849c..16eb1f0e3f8ad34e615931882ab2896db485f457 100644 --- a/tensorflow/contrib/autograph/converters/logical_expressions.py +++ b/tensorflow/contrib/autograph/converters/logical_expressions.py @@ -23,10 +23,10 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer # TODO(mdan): Properly extrack boolean ops according to lazy eval rules. @@ -39,11 +39,11 @@ from tensorflow.contrib.autograph.pyct import transformer SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND' -class LogicalExpressionTransformer(transformer.Base): +class LogicalExpressionTransformer(converter.Base): """Converts logical expressions to corresponding TF calls.""" - def __init__(self, context): - super(LogicalExpressionTransformer, self).__init__(context) + def __init__(self, ctx): + super(LogicalExpressionTransformer, self).__init__(ctx) # TODO(mdan): Look into replacing with bitwise operators instead. # TODO(mdan): Skip replacing if the function is trivial. self.op_mapping = { @@ -128,5 +128,5 @@ class LogicalExpressionTransformer(transformer.Base): return right -def transform(node, context): - return LogicalExpressionTransformer(context).visit(node) +def transform(node, ctx): + return LogicalExpressionTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/logical_expressions_test.py b/tensorflow/contrib/autograph/converters/logical_expressions_test.py index 2814060c4d831e4dddacb3dcbcbe1db42160db20..48186024a9da7b41fa7ff9a8ab18f3477ba09c8f 100644 --- a/tensorflow/contrib/autograph/converters/logical_expressions_test.py +++ b/tensorflow/contrib/autograph/converters/logical_expressions_test.py @@ -18,13 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import logical_expressions +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class GradientsFunctionTest(converter_test_base.TestCase): +class GradientsFunctionTest(converter_testing.TestCase): def test_equals(self): diff --git a/tensorflow/contrib/autograph/converters/name_scopes.py b/tensorflow/contrib/autograph/converters/name_scopes.py index dfee529abaa8c14d9b408819b32c5199500a2c2f..dd6c6bf960c52d094a16d4cd72fa84f65b9322a1 100644 --- a/tensorflow/contrib/autograph/converters/name_scopes.py +++ b/tensorflow/contrib/autograph/converters/name_scopes.py @@ -20,11 +20,11 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -class FunctionNameScopeTransformer(transformer.Base): +class FunctionNameScopeTransformer(converter.Base): """Wrap a function body with a `name_scope` of the function name.""" def _name_for_current_scope(self): @@ -70,5 +70,5 @@ class FunctionNameScopeTransformer(transformer.Base): return node -def transform(node, context): - return FunctionNameScopeTransformer(context).visit(node) +def transform(node, ctx): + return FunctionNameScopeTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/name_scopes_test.py b/tensorflow/contrib/autograph/converters/name_scopes_test.py index 17692cbd880dbc1db4bb40ad7345e27907499f9d..444d0bcd469f35689d078debe3622f930dbac723 100644 --- a/tensorflow/contrib/autograph/converters/name_scopes_test.py +++ b/tensorflow/contrib/autograph/converters/name_scopes_test.py @@ -18,14 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import name_scopes +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.platform import test -class FunctionNameScopeTransformer(converter_test_base.TestCase): +class FunctionNameScopeTransformer(converter_testing.TestCase): def test_basic(self): diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards.py b/tensorflow/contrib/autograph/converters/side_effect_guards.py index 3bcb2d3c42c6e0663c8f78523199a364b6ac231f..b808604f0ab2d42f41a560035ab046ff782a3431 100644 --- a/tensorflow/contrib/autograph/converters/side_effect_guards.py +++ b/tensorflow/contrib/autograph/converters/side_effect_guards.py @@ -36,11 +36,11 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import qual_names from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -59,14 +59,9 @@ class SymbolNamer(object): raise NotImplementedError() -class SideEffectGuardTransformer(transformer.Base): +class SideEffectGuardTransformer(converter.Base): """Adds control dependencies to functions with side effects.""" - def __init__(self, context): - super(SideEffectGuardTransformer, self).__init__(context) - - # pylint:disable=invalid-name - def _visit_and_reindent(self, nodes): new_nodes = [] current_dest = new_nodes @@ -149,7 +144,7 @@ class SideEffectGuardTransformer(transformer.Base): s for s in guarded_args if s not in args_scope.parent.modified) aliased_new_names = tuple( qual_names.QN( - self.context.namer.new_symbol( + self.ctx.namer.new_symbol( s.ssf(), args_scope.parent.referenced)) for s in need_alias) alias_map = dict(zip(need_alias, aliased_new_names)) if len(guarded_args) == 1: @@ -183,8 +178,6 @@ class SideEffectGuardTransformer(transformer.Base): (node.body, alias_map)) return node - # pylint:enable=invalid-name - -def transform(node, context): - return SideEffectGuardTransformer(context).visit(node) +def transform(node, ctx): + return SideEffectGuardTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py index ce0ce33243a1352107eb8121050ee76474869809..a7ad8efed4c88e15ce9dc14cb02e5e035602013d 100644 --- a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py +++ b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import side_effect_guards +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops @@ -29,7 +29,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test -class SideEffectGuardsTest(converter_test_base.TestCase): +class SideEffectGuardsTest(converter_testing.TestCase): def test_side_effect_on_return_only_variable(self): diff --git a/tensorflow/contrib/autograph/converters/single_return.py b/tensorflow/contrib/autograph/converters/single_return.py index bcc9ca9dfeb00ef2d2e60edf6a1abfba19a1bad7..a351cd81b82f7fb32f62ac1579355ace0501759d 100644 --- a/tensorflow/contrib/autograph/converters/single_return.py +++ b/tensorflow/contrib/autograph/converters/single_return.py @@ -20,21 +20,21 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno # TODO(mdan): Move this logic into transformer_base. -class BodyVisitor(transformer.Base): +class BodyVisitor(converter.Base): """Walks breadth- or depth-first the list-of-nodes bodies of AST nodes.""" - def __init__(self, context, depth_first=False): + def __init__(self, ctx, depth_first=False): + super(BodyVisitor, self).__init__(ctx) self.depth_first = depth_first self.changes_made = False - super(BodyVisitor, self).__init__(context) def visit_nodelist(self, nodelist): for node in nodelist: @@ -144,13 +144,13 @@ def contains_return(node): return False -class LiftReturn(transformer.Base): +class LiftReturn(converter.Base): """Move return statements out of If and With blocks.""" - def __init__(self, context): + def __init__(self, ctx): + super(LiftReturn, self).__init__(ctx) self.changes_made = False self.common_return_name = None - super(LiftReturn, self).__init__(context) def visit_If(self, node): # Depth-first traversal of if statements @@ -195,8 +195,8 @@ class LiftReturn(transformer.Base): last_return_name = self.common_return_name body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) referenced_names = body_scope.referenced - self.common_return_name = self.context.namer.new_symbol( - 'return_', referenced_names) + self.common_return_name = self.ctx.namer.new_symbol('return_', + referenced_names) node = self.generic_visit(node) self.common_return_name = last_return_name return node @@ -265,7 +265,7 @@ class DetectReturnInFunctionDef(gast.NodeVisitor): 'Each function definition should contain at least one return.') -def transform(node, context): +def transform(node, ctx): """Ensure a function has only a single return. This transforms an AST node with multiple returns successively into containing @@ -280,8 +280,8 @@ def transform(node, context): this is an error. Args: - node: an AST node to transform - context: a context object + node: ast.AST + ctx: converter.EntityContext Returns: new_node: an AST with a single return value @@ -301,10 +301,10 @@ def transform(node, context): while True: # Try to lift all returns out of if statements and with blocks - lr = LiftReturn(context) + lr = LiftReturn(ctx) node = lr.visit(node) changes_made = lr.changes_made - fe = FoldElse(context) + fe = FoldElse(ctx) node = fe.visit(node) changes_made = changes_made or fe.changes_made diff --git a/tensorflow/contrib/autograph/converters/single_return_test.py b/tensorflow/contrib/autograph/converters/single_return_test.py index d483005a09537ea8227814f65aa7e6402c853f60..1f0de4310e370235a4a7bfeaa61bd519a81aff47 100644 --- a/tensorflow/contrib/autograph/converters/single_return_test.py +++ b/tensorflow/contrib/autograph/converters/single_return_test.py @@ -18,13 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import single_return +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework.ops import name_scope from tensorflow.python.platform import test -class SingleReturnTest(converter_test_base.TestCase): +class SingleReturnTest(converter_testing.TestCase): def compiled_fn(self, test_fn, *args): node = self.parse_and_analyze(test_fn, {}) diff --git a/tensorflow/contrib/autograph/converters/slices.py b/tensorflow/contrib/autograph/converters/slices.py index 85aeda9c4164eb70329bd50f789eea5441c8fc87..3f5fc57125a8b65faf1e3a377d7984ff05b3245c 100644 --- a/tensorflow/contrib/autograph/converters/slices.py +++ b/tensorflow/contrib/autograph/converters/slices.py @@ -20,12 +20,12 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -class SliceTransformer(transformer.Base): +class SliceTransformer(converter.Base): """Converts slicing operations to their TF counterpart. Currently, relying on the default slice operator that Tensor uses is @@ -79,5 +79,5 @@ class SliceTransformer(transformer.Base): template, target=node.value, key=node.slice, dtype=dtype) -def transform(node, context): - return SliceTransformer(context).visit(node) +def transform(node, ctx): + return SliceTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/slices_test.py b/tensorflow/contrib/autograph/converters/slices_test.py index 6c2d7e1ea1a6c46fcc3a2c6972a24507646ef858..df9a4c8bab66f24374605b45bc90bc2730431323 100644 --- a/tensorflow/contrib/autograph/converters/slices_test.py +++ b/tensorflow/contrib/autograph/converters/slices_test.py @@ -19,15 +19,15 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph import utils -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import slices +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import list_ops from tensorflow.python.platform import test -class SliceTest(converter_test_base.TestCase): +class SliceTest(converter_testing.TestCase): def test_index_access(self): diff --git a/tensorflow/contrib/autograph/core/BUILD b/tensorflow/contrib/autograph/core/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..833f9dced81bd651244d281322c830bb1c88b259 --- /dev/null +++ b/tensorflow/contrib/autograph/core/BUILD @@ -0,0 +1,59 @@ +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "core", + srcs = [ + "config.py", + "converter.py", + "naming.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/pyct/static_analysis", + "//tensorflow/contrib/autograph/utils", + ], +) + +py_library( + name = "test_lib", + srcs = [ + "converter_testing.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":core", + "//tensorflow/contrib/autograph/operators", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/pyct/static_analysis", + "//tensorflow/contrib/autograph/utils", + "@gast_archive//:gast", + "@six_archive//:six", + ], +) + +py_test( + name = "naming_test", + srcs = ["naming_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":core", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/impl/config.py b/tensorflow/contrib/autograph/core/config.py similarity index 100% rename from tensorflow/contrib/autograph/impl/config.py rename to tensorflow/contrib/autograph/core/config.py diff --git a/tensorflow/contrib/autograph/core/converter.py b/tensorflow/contrib/autograph/core/converter.py new file mode 100644 index 0000000000000000000000000000000000000000..54e6aa0f3bbb9059e044861362407cb5050240b4 --- /dev/null +++ b/tensorflow/contrib/autograph/core/converter.py @@ -0,0 +1,210 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converter construction support. + +This module contains a base class for all converters, as well as supporting +structures. These structures are referred to as contexts. + +The class hierarchy is as follows: + + + [extends] converter.Base + [extends] transformer.Base + [extends] gast.nodeTransformer + [uses] transfomer.SourceInfo + [uses] converter.EntityContext + [uses] converter.ProgramContext + [uses] transfomer.SourceInfo + +converter.Base is a specialization of transformer.Base for AutoGraph. It's a +very lightweight subclass that adds a `ctx` attribute holding the corresponding +EntityContext object (see below). Note that converters are not reusable, and +`visit` will raise an error if called more than once. + +converter.EntityContext contains mutable state associated with an entity that +the converter processes. + +converter.ProgramContext contains mutable state across related entities. For +example, when converting several functions that call one another, the +ProgramContext should be shared across these entities. + +Below is the overal flow at conversion: + + program_ctx = ProgramContext(, , ...) + while : + entity, source_info = + entity_ctx = EntityContext(program_ctx, source_info) + for : + converter = ConverterClass(entity_ctx) + + # May update entity_ctx and program_ctx + entity = converter.visit(entity) + + + +Note that pyct contains a small number of transformers used for static analysis. +These implement transformer.Base, rather than converter.Base, to avoid a +dependency on AutoGraph. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.contrib.autograph.core import config +from tensorflow.contrib.autograph.core import naming +from tensorflow.contrib.autograph.pyct import transformer + +# TODO(mdan): These contexts can be refactored into first class objects. +# For example, we could define Program and Entity abstractions that hold on +# to the actual entity and have conversion methods. + + +class ProgramContext(object): + """ProgramContext keeps track of converting function hierarchies. + + This object is mutable, and is updated during conversion. Not thread safe. + + Attributes: + recursive: bool, whether to recursively convert any functions that the + decorator function may call. + autograph_decorators: Tuple[Callable, ...], decorator functions that belong + to AutoGraph. These require special treatment. + dependency_cache: Dict[Any, ast.AST], the original entities mapped to their + converted AST + additional_imports: Set[Any], additional entities which for any reason + cannot be attached after loading and need to be explicitly imported + in the generated code + name_map: Dict[str, str], map of original entity name to the name of + their converted counterparts + autograph_module: Module, a reference to the autograph module. This + needs to be specified by the caller to avoid circular dependencies. + uncompiled_modules: Set[Tuple[str, ...]], with each tuple representing the + fully qualified name of a package containing functions that will not be + compiled. + required_imports: str, containing an import statement on each line. These + are all the imports necessary for the compiled code to run, in addition + to the closures of each entity, which are attached dynamically. + """ + + def __init__( + self, + recursive, + autograph_decorators, + partial_types, + autograph_module, + uncompiled_modules, + ): + self.recursive = recursive + self.autograph_decorators = autograph_decorators + self.partial_types = partial_types if partial_types else () + self.autograph_module = autograph_module + self.uncompiled_modules = uncompiled_modules + + # Required to output dependencies in discovery order, which should match + # the reverse dependency order. + self.dependency_cache = collections.OrderedDict() + self.additional_imports = set() + self.name_map = {} + + @property + def required_imports(self): + """Returns a block containing all imports required by the converted code.""" + # TODO(mdan): Check that these don't clobber one another. + return '\n'.join(config.COMPILED_IMPORT_STATEMENTS + + tuple(self.additional_imports)) + + def new_namer(self, namespace): + return naming.Namer(namespace, self.recursive, self.name_map, + self.partial_types) + + def update_name_map(self, namer): + """Updates renamed_calls based on the recent activity from the namer. + + Whenever we convert a new entity, any references to other entities are being + renamed to match their soon-to-be-converted counterparts. The namer keeps + track of these renames. When conversion is complete, we copy those renames + so that when those referenced entities are being converted, their new name + matches. + + Args: + namer: naming.Namer + + Raises: + ValueError: when an entity was renamed twice and to different names. + """ + # TODO(mdan): Have call_trees do this directly. + # This is done so indirectly, via the namer, for historic reasons. But + # now we can have the converter that does the rename record the new name + # as well and skip this step altogether. + for o, name in namer.renamed_calls.items(): + if o in self.name_map: + if self.name_map[o] != name: + raise ValueError( + 'Calls to %s were converted using multiple names (%s). This is ' + 'possible when an entity with one of these names already ' + 'existed. To fix, avoid using any of these names.' % + (o, (name, self.name_map[o]))) + else: + self.name_map[o] = name + + def add_to_cache(self, original_entity, converted_ast): + self.dependency_cache[original_entity] = converted_ast + + +class EntityContext(object): + """Tracks the conversion of a single entity. + + This object is mutable, and is updated during conversion. Not thread safe. + + Attributes: + namer: Namer + info: transformer.EntityInfo + program: ProgramContext + """ + + def __init__(self, namer, entity_info, program_ctx): + self.namer = namer + self.info = entity_info + self.program = program_ctx + + +class Base(transformer.Base): + """All converters should inherit from this class. + + Attributes: + ctx: EntityContext + """ + + def __init__(self, ctx): + super(Base, self).__init__(ctx.info) + self.ctx = ctx # Keeping this short because it's used frequently. + + self._used = False + self._ast_depth = 0 + + def visit(self, node): + if not self._ast_depth: + if self._used: + raise ValueError('converter objects cannot be reused') + self._used = True + + self._ast_depth += 1 + try: + return super(Base, self).visit(node) + finally: + self._ast_depth -= 1 diff --git a/tensorflow/contrib/autograph/converters/converter_test_base.py b/tensorflow/contrib/autograph/core/converter_testing.py similarity index 80% rename from tensorflow/contrib/autograph/converters/converter_test_base.py rename to tensorflow/contrib/autograph/core/converter_testing.py index 41c2e71702e7e3ee3811a2cbee27c8c988eb3a5c..0e46aacc1216d2dbd9d34ad0e72ca8251094bddc 100644 --- a/tensorflow/contrib/autograph/converters/converter_test_base.py +++ b/tensorflow/contrib/autograph/core/converter_testing.py @@ -23,17 +23,24 @@ import imp from tensorflow.contrib.autograph import operators from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.core import config +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import compiler -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import pretty_printer from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis import activity from tensorflow.contrib.autograph.pyct.static_analysis import live_values from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.platform import test +def imported_decorator(f): + return lambda a: f(a) + 1 + + +# TODO(mdan): We might be able to use the real namer here. class FakeNamer(object): """A fake namer that uses a global counter to generate unique names.""" @@ -114,23 +121,32 @@ class TestCase(test.TestCase): arg_types=None, include_type_analysis=True, owner_type=None, - recursive=True): + recursive=True, + autograph_decorators=()): node, source = parser.parse_entity(test_fn) - ctx = context.EntityContext( - namer=namer or FakeNamer(), + + if namer is None: + namer = FakeNamer() + program_ctx = converter.ProgramContext( + recursive=recursive, + autograph_decorators=autograph_decorators, + partial_types=None, + autograph_module=None, + uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) + entity_info = transformer.EntityInfo( source_code=source, - source_file=None, + source_file='', namespace=namespace, arg_values=None, arg_types=arg_types, - owner_type=owner_type, - recursive=recursive, - type_annotation_func=utils.set_element_type) + owner_type=owner_type) + ctx = converter.EntityContext(namer, entity_info, program_ctx) + node = qual_names.resolve(node) - node = activity.resolve(node, ctx) - node = live_values.resolve(node, ctx, {}) + node = activity.resolve(node, entity_info) + node = live_values.resolve(node, entity_info, {}) if include_type_analysis: - node = type_info.resolve(node, ctx) - node = live_values.resolve(node, ctx, {}) + node = type_info.resolve(node, entity_info) + node = live_values.resolve(node, entity_info, {}) self.ctx = ctx return node diff --git a/tensorflow/contrib/autograph/impl/naming.py b/tensorflow/contrib/autograph/core/naming.py similarity index 100% rename from tensorflow/contrib/autograph/impl/naming.py rename to tensorflow/contrib/autograph/core/naming.py diff --git a/tensorflow/contrib/autograph/impl/naming_test.py b/tensorflow/contrib/autograph/core/naming_test.py similarity index 98% rename from tensorflow/contrib/autograph/impl/naming_test.py rename to tensorflow/contrib/autograph/core/naming_test.py index 73fc0894655cb49e4f61bf8ca51995b06feb3072..d2bebd0478b1074e421b5da1427a0dbaf91b6c9f 100644 --- a/tensorflow/contrib/autograph/impl/naming_test.py +++ b/tensorflow/contrib/autograph/core/naming_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.impl import naming +from tensorflow.contrib.autograph.core import naming from tensorflow.python.platform import test diff --git a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb index d62390494b78c415212ba91ac914cdfee324f971..0702273fac15da61a72d66d8344a5add32ad12a6 100644 --- a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb +++ b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb @@ -570,7 +570,7 @@ " autograph.utils.set_element_type(numbers, tf.int32)\n", " for i in range(n):\n", " numbers.append(i)\n", - " return numbers.stack() # Stack the list so that it can be used as a Tensor\n", + " return autograph.stack(numbers) # Stack the list so that it can be used as a Tensor\n", "\n", "\n", "tf_f = autograph.to_graph(f)\n", @@ -648,7 +648,7 @@ " if not is_prime:\n", " continue\n", " primes.append(i)\n", - " all_primes = primes.stack()\n", + " all_primes = autograph.stack(primes)\n", "\n", " print('The prime numbers less than', n, 'are:')\n", " print(all_primes)\n", @@ -953,8 +953,9 @@ " train_accuracies.append(step_train_accuracy)\n", " test_accuracies.append(step_test_accuracy)\n", " i += 1\n", - " return (train_losses.stack(), test_losses.stack(), train_accuracies.stack(),\n", - " test_accuracies.stack())" + " return (autograph.stack(train_losses), autograph.stack(test_losses),\n", + " autograph.stack(train_accuracies),\n", + " autograph.stack(test_accuracies))" ], "execution_count": 0, "outputs": [] @@ -1236,7 +1237,7 @@ " cell_output, (state, output) = cell.call(ch, (state, output))\n", " hidden_outputs.append(cell_output)\n", " i += 1\n", - " hidden_outputs = hidden_outputs.stack()\n", + " hidden_outputs = autograph.stack(hidden_outputs)\n", " if training:\n", " hidden_outputs = tf.nn.dropout(hidden_outputs, 0.5)\n", " return hidden_outputs\n", diff --git a/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb index 324b23c24b5a7970d7f20ed955839ba1cf1774fc..44532cb078f9bd1578172f8a7d8a4b55cd21a7cb 100644 --- a/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb +++ b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb @@ -190,7 +190,6 @@ " self.upper_cell = tf.contrib.rnn.LSTMBlockCell(128)\n", " self.relu_layer = tf.layers.Dense(3, activation=tf.nn.relu)\n", "\n", - "\n", " def _rnn_layer(self, chars, cell, batch_size, training):\n", " \"\"\"A single RNN layer.\n", "\n", @@ -203,13 +202,12 @@ " Returns:\n", " A Tensor of shape (max_sequence_length, batch_size, output_size).\n", " \"\"\"\n", - " hidden_outputs = []\n", - " autograph.utils.set_element_type(hidden_outputs, tf.float32)\n", + " hidden_outputs = tf.TensorArray(tf.float32, 0, True)\n", " state, output = cell.zero_state(batch_size, tf.float32)\n", " for ch in chars:\n", " cell_output, (state, output) = cell.call(ch, (state, output))\n", " hidden_outputs.append(cell_output)\n", - " hidden_outputs = hidden_outputs.stack()\n", + " hidden_outputs = autograph.stack(hidden_outputs)\n", " if training:\n", " hidden_outputs = tf.nn.dropout(hidden_outputs, 0.5)\n", " return hidden_outputs\n", @@ -223,7 +221,7 @@ "\n", "\n", " def call(self, inputs, training=False):\n", - " \"\"\"The RNN model code. Uses Eager and \n", + " \"\"\"The RNN model code. Uses Eager.\n", "\n", " The model consists of two RNN layers (made by lower_cell and upper_cell),\n", " followed by a fully connected layer with ReLU activation.\n", @@ -243,7 +241,8 @@ " seq = self._rnn_layer(seq, self.upper_cell, batch_size, training)\n", "\n", " # Grab just the end-of-sequence from each output.\n", - " indices = tf.stack([length - 1, range(batch_size)], axis=1)\n", + " indices = (length - 1, range(batch_size))\n", + " indices = tf.stack(indices, 1)\n", " sequence_ends = tf.gather_nd(seq, indices)\n", " return self.relu_layer(sequence_ends)\n", "\n", @@ -381,7 +380,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 107, "metadata": { "colab": { "autoexec": { @@ -392,9 +391,9 @@ }, "colab_type": "code", "executionInfo": { - "elapsed": 10604, + "elapsed": 5454, "status": "ok", - "timestamp": 1524095272039, + "timestamp": 1529952160455, "user": { "displayName": "", "photoUrl": "", @@ -403,7 +402,7 @@ "user_tz": 240 }, "id": "2pg1AfbxBJQq", - "outputId": "9c924b4f-06e1-4538-976c-a3e1ddac5660", + "outputId": "4aef3052-f7c7-4bb1-a0a2-73fef2e96efb", "slideshow": { "slide_type": "-" } @@ -413,7 +412,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Eval loss at step 100: 0.0674834\n" + "Eval loss at step 100: 0.0705221\n" ] } ], @@ -423,8 +422,8 @@ " 'learning_rate': 0.01,\n", "}\n", "\n", - "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv\"\n", - "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv\"\n", + "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/train.csv\"\n", + "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/test.csv\"\n", "data_dir = \"tmp/rnn/data\"\n", "\n", "regressor = tf.estimator.Estimator(\n", @@ -457,7 +456,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 108, "metadata": { "colab": { "autoexec": { @@ -468,9 +467,9 @@ }, "colab_type": "code", "executionInfo": { - "elapsed": 7990, + "elapsed": 3432, "status": "ok", - "timestamp": 1524095280105, + "timestamp": 1529952163923, "user": { "displayName": "", "photoUrl": "", @@ -479,7 +478,7 @@ "user_tz": 240 }, "id": "dxHex2tUN_10", - "outputId": "2b889e5a-b9ed-4645-bf03-d98f26c72101", + "outputId": "1ff438f2-b045-4f4e-86a0-4dae7503f6b2", "slideshow": { "slide_type": "slide" } @@ -491,12 +490,12 @@ "\u003clink rel=stylesheet type=text/css href='/nbextensions/google.colab/tabbar.css'\u003e\u003c/link\u003e" ], "text/plain": [ - "\u003cIPython.core.display.HTML at 0x7f3f36aa6cd0\u003e" + "\u003cIPython.core.display.HTML at 0x7fcd7222a110\u003e" ] }, "metadata": { "tags": [ - "outputarea_id1" + "outputarea_id3" ] }, "output_type": "display_data" @@ -507,12 +506,12 @@ "\u003cscript src='/nbextensions/google.colab/tabbar_main.min.js'\u003e\u003c/script\u003e" ], "text/plain": [ - "\u003cIPython.core.display.HTML at 0x7f3eca67f7d0\u003e" + "\u003cIPython.core.display.HTML at 0x7fcd7222a8d0\u003e" ] }, "metadata": { "tags": [ - "outputarea_id1" + "outputarea_id3" ] }, "output_type": "display_data" @@ -520,15 +519,15 @@ { "data": { "text/html": [ - "\u003cdiv id=\"id1\"\u003e\u003c/div\u003e" + "\u003cdiv id=\"id3\"\u003e\u003c/div\u003e" ], "text/plain": [ - "\u003cIPython.core.display.HTML at 0x7f3eca67f8d0\u003e" + "\u003cIPython.core.display.HTML at 0x7fcd7222a050\u003e" ] }, "metadata": { "tags": [ - "outputarea_id1" + "outputarea_id3" ] }, "output_type": "display_data" @@ -536,16 +535,16 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa22-4362-11e8-91ec-c8d3ffb5fbe0\"] = colab_lib.createTabBar({\"contentBorder\": [\"0px\"], \"elementId\": \"id1\", \"borderColor\": [\"#a7a7a7\"], \"contentHeight\": [\"initial\"], \"tabNames\": [\"RNN Colorbot\"], \"location\": \"top\", \"initialSelection\": 0});\n", - "//# sourceURL=js_71b9087b6d" + "window[\"8a03307e-78a7-11e8-99f9-c8d3ffb5fbe0\"] = colab_lib.createTabBar({\"contentBorder\": [\"0px\"], \"elementId\": \"id3\", \"contentHeight\": [\"initial\"], \"tabNames\": [\"RNN Colorbot\"], \"location\": \"top\", \"initialSelection\": 0, \"borderColor\": [\"#a7a7a7\"]});\n", + "//# sourceURL=js_dc5d7f2784" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67f950\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222a190\u003e" ] }, "metadata": { "tags": [ - "outputarea_id1" + "outputarea_id3" ] }, "output_type": "display_data" @@ -553,16 +552,16 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa23-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_e390445f33" + "window[\"8a03307f-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_be7950150b" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67f990\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222ac90\u003e" ] }, "metadata": { "tags": [ - "outputarea_id1" + "outputarea_id3" ] }, "output_type": "display_data" @@ -570,17 +569,17 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa24-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_241dd76d85" + "window[\"8a033080-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_d0c3bd4eaa" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fc50\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222aad0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -588,17 +587,17 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa25-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_60c64e3d50" + "window[\"8a033081-78a7-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id3_content_0\");\n", + "//# sourceURL=js_f10f6eba86" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fd90\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222aed0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -606,17 +605,17 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa26-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"e8ddfa25-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_14ea437cbd" + "window[\"8a033082-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8a033081-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_ff29697179" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fe10\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222abd0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -624,17 +623,17 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa27-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_09294c2226" + "window[\"8a033083-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_ff85295dc7" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fcd0\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222ab90\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -642,17 +641,17 @@ { "data": { "application/javascript": [ - "window[\"ec965514-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"e8ddfa24-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_e5e8266997" + "window[\"8b18d8dc-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8a033080-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_ed7aabfedb" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fe10\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222a110\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -660,17 +659,17 @@ { "data": { "application/javascript": [ - "window[\"ec965515-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_07a097f0ee" + "window[\"8b18d8dd-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_c86f8feaf4" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fc90\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222acd0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -678,17 +677,17 @@ { "data": { "application/javascript": [ - "window[\"ec965516-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_790d669ca8" + "window[\"8b18d8de-78a7-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id3_content_0\");\n", + "//# sourceURL=js_4d0fde6662" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67f8d0\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222ae50\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -696,17 +695,17 @@ { "data": { "application/javascript": [ - "window[\"ec965517-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec965516-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_d30df771f0" + "window[\"8b18d8df-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8de-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_3f66d52720" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fd90\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222a210\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -714,32 +713,32 @@ { "data": { "application/javascript": [ - "window[\"ec965518-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_8a43a2da4b" + "window[\"8b18d8e0-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_375f5ae6d7" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fc50\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222a310\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" }, { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQwAAAENCAYAAAD60Fs2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACMBJREFUeJzt3F+I1XX+x/G32zjiFERUpgaFd2JBzOg5joX4h0SiMgmM\n/uhVGIlgFBlERGB3hUEkhkRdtDfRP1ACL6KpLBqcguxCjEAkmGamQcSohFHzsxe7O6zssvsydtff\n+ns8rs758j3f8z7fiyef7/k3o7XWCiDwh4s9APC/QzCAmGAAMcEAYoIBxAQDiAkGF8XTTz9d3W63\n7rvvvhoZGakVK1Zc7JEICMYlbvXq1TU8PHyxxzjPV199VcPDw/XZZ5/V22+/XVVVM2bMuMhTkRAM\n/qt+++23+uGHH+r666+vWbNmXexxuECCcQl76qmnanx8vLZs2VIDAwP1+uuv1zfffFP3339/dTqd\nWr9+fY2MjEzvv2nTpnr55ZfrgQceqIGBgXr44Yfr5MmTVVV1+vTp2r59ey1durQ6nU5t2LChTpw4\nUVVVk5OTtWXLllq6dGmtXbu23nnnnelj7tq1q7Zt21bbt2+vJUuW1HvvvVfPPvtsHTp0qAYGBmrX\nrl1/N/fRo0dr06ZN1el06u67766hoaGqqhodHa1OpzO93zPPPFO33nrr9P3t27fXm2+++e89iZyv\ncUlbtWpVGx4ebq21NjEx0brdbjtw4EBrrbUvvviidbvdduLEidZaaxs3bmxr1qxp33//fZuammob\nN25sO3fubK219tZbb7VHH320TU1NtXPnzrXDhw+3X375pbXW2kMPPdR27NjRTp8+3Y4cOdIGBwen\nn/OVV15pN910U/voo49aa61NTU21999/vz344IPTMx48eLCtWLGitdbamTNn2po1a9qePXvamTNn\n2vDwcOvv72/Hjh2bfj2HDx9urbW2du3advvtt7ejR4+21lpbuXJlO3LkyH/qVNJas8L4f6D95edC\n+/btq5UrV9by5curqmrZsmV1880316effjq977333ls33HBD9fb21h133FFHjhypqqqenp46efJk\nHTt2rGbMmFGLFi2qyy+/vCYmJurrr7+uJ598smbOnFkLFy6sDRs21N69e6eP2d/fX6tXr66qqt7e\n3n8666FDh+rUqVP1yCOPVE9PTw0ODtaqVavqgw8+qKqqJUuW1MjISB0/fryqqtauXVtffvlljY6O\n1q+//loLFy78N501/pGeiz0A/z1jY2O1f//++vjjj6vqzyE5e/ZsLVu2bHqfa665Zvr27Nmz69Sp\nU1VVdc8999TExEQ98cQT9fPPP9e6devq8ccfr8nJybryyitr9uzZ04+bP39+HT58ePr+3Llz4xkn\nJydr3rx5522bP39+TU5OVlVVp9OpoaGhuu6666rb7Va32629e/dWb29vLV68+ALOBr+HYFzi/vbT\nh3nz5tX69etrx44dF3ycnp6e2rp1a23durXGxsZq8+bNtWDBgrrtttvqp59+qlOnTlVfX19VVY2P\nj9ecOXP+4Qz/ypw5c2p8fPy8bWNjY7VgwYKqqup2u/Xiiy/WvHnzqtPp1MDAQD333HPV29tb3W73\ngl8XF8YlySXu2muvrdHR0aqqWrduXQ0NDdXnn39e586dq6mpqRoZGakff/zxXx7n4MGD9d1339W5\nc+eqr6+venp66rLLLqu5c+dWf39/vfTSS3X69On69ttv6913361169b9rnlvueWW6uvrq9dee63O\nnj1bBw8erE8++aTuvPPOqqq68cYba9asWbVv377qdDp1xRVX1NVXX10ffvjheW+I8p8hGJe4zZs3\n1+7du6vb7db+/ftr9+7dtWfPnlq2bFmtWrWq3njjjen3OP7ZSuD48eO1bdu2Wrx4cd111121dOnS\n6Sjs3LmzRkdHa/ny5bVt27Z67LHHzrvMuRAzZ86sV199tQ4cOFCDg4P1/PPP1wsvvDC9wqj68yrj\nqquumr7U+WsoFi1a9Luek9yM1vyBDpCxwgBiggHEBAOICQYQ+z/7PYzjf/QRGVxM12z68u+2WWEA\nMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHE\nBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhAT\nDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEww\ngJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEA\nYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOI\nCQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAm\nGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhg\nADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIB\nxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQ\nEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBM\nMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHB\nAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQD\niAkGEBMMIDajtdYu9hDA/wYrDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEA\nYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4j9CY2LTAbbRbWuAAAAAElFTkSuQmCC\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQwAAAENCAYAAAD60Fs2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAABTFJREFUeJzt3C+LV30eh/HP6EZvbP4ZJmkXDA6oQdZRMIhYLIKCMGVA\nyyaLT2ERLMqEDfoUFA2y3WpRrOKoSUSECePcYUEWdsN1OzfOyr5e8ZwT3unie34cfgvb29vbAxDs\n2e0BwK9DMIBMMIBMMIBMMIBMMIBMMPipXrx4MWfOnNntGfwgweCnW1hY2O0J/CDBYEe2trZ2ewI/\nkWDwh509e3bW19fn0qVLc/z48dnY2Jhbt27NyZMn59y5c/Pw4cPvz25ubs7t27dneXl5Ll68OC9f\nvtzF5ezUX3Z7AL+mJ0+ezPr6+uzfv3+uXr0658+fn7t3787GxsbcuHFjjhw5MqdPn5579+7N27dv\n5/nz5/P169dZXV3d7ensgBMGP+T69etz8ODBef369Xz69GnW1tZm7969s7S0NFeuXJnHjx/PzMzT\np09nbW1tfvvttzl48OBcu3Ztl5ezE04Y/JBDhw7NzMy7d+/mw4cPs7y8PDMz29vb8+3btzlx4sTM\nzHz8+PH7szMzi4uLP38sfxrBYEcOHz48S0tL8+zZs/96/8CBA7OxsTFHjx6dmX8Fhl+XVxJ25Nix\nY7Nv375ZX1+fzc3N2dramjdv3nz/cfPChQvz4MGD+fz587x//34ePXq0y4vZCcHgD/v37yj27Nkz\n9+/fn1evXs3KysqcOnVq7ty5M1++fJmZmZs3b87i4uKsrKzM6urqXL58ebdm8ydY8Ac6QOWEAWSC\nAWSCAWSCAWT/s99h/P3GX3d7Avxf+9s//vkf15wwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgGxhe3t7e7dHAL8GJwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwg\nEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwg+x1QoZHG4XIe4gAAAABJRU5ErkJggg==\n", "text/plain": [ - "\u003cmatplotlib.figure.Figure at 0x7f3ecc00bf10\u003e" + "\u003cmatplotlib.figure.Figure at 0x7fcd0d02dc90\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -748,17 +747,17 @@ { "data": { "application/javascript": [ - "window[\"ec965519-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec965515-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_893ad561f4" + "window[\"8b18d8e1-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8dd-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_34b0509660" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55c90\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e850\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -766,17 +765,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551a-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_2d99e0ac17" + "window[\"8b18d8e2-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_518a0f26fe" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fe50\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6ec90\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -784,17 +783,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551b-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_5c19462e32" + "window[\"8b18d8e3-78a7-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id3_content_0\");\n", + "//# sourceURL=js_17eb3ff612" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55dd0\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6eb50\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -802,17 +801,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551c-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec96551b-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_b9c8b7567b" + "window[\"8b18d8e4-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8e3-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_99da807c8e" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55a50\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6eb90\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -820,17 +819,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551d-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_fd05186348" + "window[\"8b18d8e5-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_dee01cb4b6" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55810\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e610\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -838,16 +837,16 @@ { "data": { "text/html": [ - "\u003cdiv class=id_888646481 style=\"margin-right:10px; display:flex;align-items:center;\"\u003e\u003cspan style=\"margin-right: 3px;\"\u003e\u003c/span\u003e\u003c/div\u003e" + "\u003cdiv class=id_853612217 style=\"margin-right:10px; display:flex;align-items:center;\"\u003e\u003cspan style=\"margin-right: 3px;\"\u003e\u003c/span\u003e\u003c/div\u003e" ], "text/plain": [ - "\u003cIPython.core.display.HTML at 0x7f3f32414810\u003e" + "\u003cIPython.core.display.HTML at 0x7fcd7222aa10\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -856,17 +855,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551e-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 span\");\n", - "//# sourceURL=js_efef96e882" + "window[\"8b18d8e6-78a7-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_853612217 span\");\n", + "//# sourceURL=js_8c378be329" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55710\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e990\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -875,17 +874,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551f-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ec96551e-4362-11e8-91ec-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", - "//# sourceURL=js_6eca889864" + "window[\"8b18d8e7-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"8b18d8e6-78a7-11e8-99f9-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", + "//# sourceURL=js_f0b946600c" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67f990\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e310\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -894,17 +893,17 @@ { "data": { "application/javascript": [ - "window[\"ed8ea972-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 input\");\n", - "//# sourceURL=js_f02070cc60" + "window[\"8b18d8e9-78a7-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_853612217 input\");\n", + "//# sourceURL=js_9e21b1373a" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b553d0\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6ea90\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -913,17 +912,17 @@ { "data": { "application/javascript": [ - "window[\"ed8ea973-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ed8ea972-4362-11e8-91ec-c8d3ffb5fbe0\"].remove();\n", - "//# sourceURL=js_ed9faba660" + "window[\"8b18d8ea-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"8b18d8e9-78a7-11e8-99f9-c8d3ffb5fbe0\"].remove();\n", + "//# sourceURL=js_a7764968c6" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31a95450\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e5d0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -932,17 +931,17 @@ { "data": { "application/javascript": [ - "window[\"ed8ea974-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 span\");\n", - "//# sourceURL=js_f3458d7074" + "window[\"8b18d8eb-78a7-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_853612217 span\");\n", + "//# sourceURL=js_74279d3ff0" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31a95250\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e890\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -951,17 +950,17 @@ { "data": { "application/javascript": [ - "window[\"ed8ea975-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ed8ea974-4362-11e8-91ec-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", - "//# sourceURL=js_3ffd97bd6f" + "window[\"8b18d8ec-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"8b18d8eb-78a7-11e8-99f9-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", + "//# sourceURL=js_82b6c34cdb" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31a953d0\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e8d0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -970,17 +969,17 @@ { "data": { "application/javascript": [ - "window[\"ed8ea976-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec96551a-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_7f73e8bcca" + "window[\"8b18d8ed-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8e2-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_ff6144734a" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55710\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e8d0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -1043,28 +1042,6 @@ "kind": "local" }, "name": "RNN Colorbot using Keras and Estimators", - "provenance": [ - { - "file_id": "1CtzefX39ffFibX_BqE6cRbT0UW_DdVKl", - "timestamp": 1523579810961 - }, - { - "file_id": "1DcfimonWU11tmyivKBGVrbpAl3BIOaRG", - "timestamp": 1523016192637 - }, - { - "file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K", - "timestamp": 1522238054357 - }, - { - "file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ", - "timestamp": 1521743157199 - }, - { - "file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-", - "timestamp": 1520522344607 - } - ], "version": "0.3.2", "views": {} }, diff --git a/tensorflow/contrib/autograph/impl/BUILD b/tensorflow/contrib/autograph/impl/BUILD index 02f16ae1875d6bd1fb87d19f8bfc5cae900391dd..a5438592c30021eac7183b65ccc10c36d220bc57 100644 --- a/tensorflow/contrib/autograph/impl/BUILD +++ b/tensorflow/contrib/autograph/impl/BUILD @@ -18,20 +18,19 @@ py_library( name = "impl", srcs = [ "api.py", - "config.py", "conversion.py", - "directives.py", - "naming.py", - "special_functions.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/contrib/autograph/converters", + "//tensorflow/contrib/autograph/core", "//tensorflow/contrib/autograph/operators", "//tensorflow/contrib/autograph/pyct", "//tensorflow/contrib/autograph/pyct/static_analysis", "//tensorflow/contrib/autograph/utils", + "//tensorflow/python:platform", + "//tensorflow/python:util", "@gast_archive//:gast", "@six_archive//:six", ], @@ -61,23 +60,3 @@ py_test( "@gast_archive//:gast", ], ) - -py_test( - name = "naming_test", - srcs = ["naming_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":impl", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "special_functions_test", - srcs = ["special_functions_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":impl", - "//tensorflow/python:client_testlib", - ], -) diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py index 24f87b2c14da4a3523f1e580d4362cbd3679a2cd..c7401c7df126b73ca22cdaf74a2f1fd6149d7545 100644 --- a/tensorflow/contrib/autograph/impl/api.py +++ b/tensorflow/contrib/autograph/impl/api.py @@ -27,14 +27,15 @@ import gast import six # pylint:enable=g-bad-import-order -from tensorflow.contrib.autograph.impl import config +from tensorflow.contrib.autograph.core import config +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.impl import conversion from tensorflow.contrib.autograph.pyct import compiler from tensorflow.contrib.autograph.pyct import inspect_utils -from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.utils import builtins from tensorflow.contrib.autograph.utils import py_func from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect # TODO(mdan): Properly document the type hints. @@ -70,6 +71,8 @@ def convert(recursive=False, verbose=False, arg_types=None): def wrapper(*args, **kwargs): return converted_call(f, recursive, verbose, arg_types, *args, **kwargs) + wrapper = tf_decorator.make_decorator(f, wrapper) + # Sometimes the decorator is just desugared, making it impossible to detect. # This attribute makes detection easier. setattr(wrapper, '__pyct_is_compile_decorator', True) @@ -230,20 +233,20 @@ def to_graph(e, A function with a signature identical to `o`, but which when executed it creates TF a graph that has the same functionality as the original entity. """ - conversion_map = conversion.ConversionMap( + program_ctx = converter.ProgramContext( recursive=recursive, - nocompile_decorators=(convert, do_not_convert, converted_call), + autograph_decorators=(convert, do_not_convert, converted_call), partial_types=partial_types, - api_module=tf_inspect.getmodule(to_graph)) - _, name, namespace = conversion.entity_to_graph(e, conversion_map, arg_values, + autograph_module=tf_inspect.getmodule(to_graph), + uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) + _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values, arg_types) module = gast.Module([]) - for import_line in config.COMPILED_IMPORT_STATEMENTS: - module.body.extend(parser.parse_str(import_line).body) - for dep in reversed(conversion_map.dependency_cache.values()): + for dep in reversed(program_ctx.dependency_cache.values()): module.body.append(dep) - compiled_node, compiled_src = compiler.ast_to_object(module) + compiled_node, compiled_src = compiler.ast_to_object( + module, source_prefix=program_ctx.required_imports) # The compiled code should see everything the entry entity saw. # TODO(mdan): This might not work well if the call tree spans modules? @@ -280,17 +283,16 @@ def to_code(e, Returns: String. """ - conversion_map = conversion.ConversionMap( + program_ctx = converter.ProgramContext( recursive=recursive, - nocompile_decorators=(convert, do_not_convert, converted_call), + autograph_decorators=(convert, do_not_convert, converted_call), partial_types=partial_types, - api_module=tf_inspect.getmodule(to_graph)) - conversion.entity_to_graph(e, conversion_map, arg_values, arg_types) + autograph_module=tf_inspect.getmodule(to_graph), + uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) + conversion.entity_to_graph(e, program_ctx, arg_values, arg_types) - imports = '\n'.join(config.COMPILED_IMPORT_STATEMENTS) code = '\n'.join( compiler.ast_to_source(dep, indentation) - for dep in reversed(tuple( - six.itervalues(conversion_map.dependency_cache)))) + for dep in reversed(tuple(six.itervalues(program_ctx.dependency_cache)))) - return imports + '\n\n' + code + return program_ctx.required_imports + '\n\n' + code diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py index a7737b7f448131b1c54951efa719b481e1f4d0c9..994309333209586001c9369322ec3ddeee0a508e 100644 --- a/tensorflow/contrib/autograph/impl/api_test.py +++ b/tensorflow/contrib/autograph/impl/api_test.py @@ -21,12 +21,13 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.core import config from tensorflow.contrib.autograph.impl import api -from tensorflow.contrib.autograph.impl import config from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.utils import py_func from tensorflow.python.framework import constant_op from tensorflow.python.platform import test +from tensorflow.python.util import tf_inspect tf = utils.fake_tf() @@ -154,6 +155,22 @@ class ApiTest(test.TestCase): constant_op.constant(-2)) self.assertListEqual([0, 1], sess.run(x).tolist()) + def test_decorator_preserves_argspec(self): + + class TestClass(object): + + def called_member(self, a): + if a < 0: + a = -a + return a + + called_member_converted = api.convert()(called_member) + + tc = TestClass() + self.assertListEqual( + list(tf_inspect.getfullargspec(tc.called_member)), + list(tf_inspect.getfullargspec(tc.called_member_converted))) + def test_convert_call_site_decorator(self): class TestClass(object): diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py index 7802bbbe27ec5fed891440af2f589801918b3bdd..776d19f672ebbd6b88985dda157434f2046d87e7 100644 --- a/tensorflow/contrib/autograph/impl/conversion.py +++ b/tensorflow/contrib/autograph/impl/conversion.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""High level conversion support.""" +"""Core conversion logic, serves as main point of access.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import imp import gast @@ -39,77 +38,22 @@ from tensorflow.contrib.autograph.converters import name_scopes from tensorflow.contrib.autograph.converters import side_effect_guards from tensorflow.contrib.autograph.converters import single_return from tensorflow.contrib.autograph.converters import slices -from tensorflow.contrib.autograph.impl import config -from tensorflow.contrib.autograph.impl import naming +from tensorflow.contrib.autograph.core import config +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import ast_util -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import inspect_utils from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis import activity from tensorflow.contrib.autograph.pyct.static_analysis import live_values from tensorflow.contrib.autograph.pyct.static_analysis import type_info -from tensorflow.contrib.autograph.utils import type_hints from tensorflow.python.util import tf_inspect # TODO(mdan): Might we not need any renaming at all? -class ConversionMap(object): - """ConversionMap keeps track of converting function hierarchies. - - This object is mutable, and is updated as functions are converted. - - Attributes: - recursive: Whether to recursively convert any functions that the decorator - function may call. - nocompile_decorators: tuple of decorator functions that toggle compilation - off. - dependency_cache: dict[object]: ast; maps original entities to their - converted AST - additional_imports: set(object); additional entities which for any reason - cannot be attached after loading and need to be explicitly imported - in the generated code - name_map: dict[string]: string; maps original entities to the name of - their converted counterparts - api_module: A reference to the api module. The reference needs to be passed - to avoid circular dependencies. - """ - - # TODO(mdan): Rename to ConversionContext, and pull in additional flags. - - def __init__(self, recursive, nocompile_decorators, partial_types, - api_module): - self.recursive = recursive - self.nocompile_decorators = nocompile_decorators - self.partial_types = partial_types if partial_types else () - # Required to output dependencies in discovery order, which should match - # the reverse dependency order. - self.dependency_cache = collections.OrderedDict() - self.additional_imports = set() - self.name_map = {} - self.api_module = api_module - - def new_namer(self, namespace): - return naming.Namer(namespace, self.recursive, self.name_map, - self.partial_types) - - def update_name_map(self, namer): - for o, name in namer.renamed_calls.items(): - if o in self.name_map: - if self.name_map[o] != name: - raise ValueError( - 'Calls to %s were converted using multiple names (%s). This is ' - 'possible when an entity with one of these names already ' - 'existed. To fix, avoid using any of these names.') - else: - self.name_map[o] = name - - def add_to_cache(self, original_entity, converted_ast): - self.dependency_cache[original_entity] = converted_ast - - def is_whitelisted_for_graph(o): """Check whether an entity is whitelisted for use in graph mode. @@ -128,7 +72,7 @@ def is_whitelisted_for_graph(o): return False -def entity_to_graph(o, conversion_map, arg_values, arg_types): +def entity_to_graph(o, program_ctx, arg_values, arg_types): """Compile a Python entity into equivalent TensorFlow. The function will also recursively compile all the entities that `o` @@ -139,7 +83,7 @@ def entity_to_graph(o, conversion_map, arg_values, arg_types): Args: o: A Python entity. - conversion_map: A ConversionMap object. + program_ctx: A ProgramContext object. arg_values: A dict containing value hints for symbols like function parameters. arg_types: A dict containing type hints for symbols like function @@ -157,7 +101,7 @@ def entity_to_graph(o, conversion_map, arg_values, arg_types): ValueError: if the entity type is not supported. """ if tf_inspect.isclass(o): - node, name, ns = class_to_graph(o, conversion_map) + node, name, ns = class_to_graph(o, program_ctx) elif tf_inspect.isfunction(o): # TODO(mdan): This is not a reliable mechanism. # The most reliable way is to check the source code, the AST will contain @@ -167,36 +111,35 @@ def entity_to_graph(o, conversion_map, arg_values, arg_types): 'lambda functions are not yet supported; declare the function' ' using def instead: %s' % o) else: - node, name, ns = function_to_graph(o, conversion_map, arg_values, - arg_types) + node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types) elif tf_inspect.ismethod(o): - node, name, ns = function_to_graph(o, conversion_map, arg_values, arg_types) + node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types) else: raise ValueError( 'Entity "%s" has unsupported type "%s". Only functions and classes are ' 'supported for now.' % (o, type(o))) - conversion_map.add_to_cache(o, node) - if conversion_map.recursive: + program_ctx.add_to_cache(o, node) + if program_ctx.recursive: while True: candidate = None - for obj in conversion_map.name_map.keys(): - if obj not in conversion_map.dependency_cache: + for obj in program_ctx.name_map.keys(): + if obj not in program_ctx.dependency_cache: candidate = obj break if candidate is None: break if (hasattr(candidate, 'im_class') and - getattr(candidate, 'im_class') not in conversion_map.partial_types): + getattr(candidate, 'im_class') not in program_ctx.partial_types): # Class members are converted with their objects, unless they're # only converted partially. continue - entity_to_graph(candidate, conversion_map, {}, {}) + entity_to_graph(candidate, program_ctx, {}, {}) return node, name, ns -def class_to_graph(c, conversion_map): +def class_to_graph(c, program_ctx): """Specialization of `entity_to_graph` for classes.""" converted_members = {} method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m) @@ -211,7 +154,7 @@ def class_to_graph(c, conversion_map): continue node, _, namespace = function_to_graph( m, - conversion_map=conversion_map, + program_ctx=program_ctx, arg_values={}, arg_types={'self': (c.__name__, c)}, owner_type=c) @@ -220,14 +163,14 @@ def class_to_graph(c, conversion_map): else: class_namespace.update(namespace) converted_members[m] = node - namer = conversion_map.new_namer(class_namespace) + namer = program_ctx.new_namer(class_namespace) class_name = namer.compiled_class_name(c.__name__, c) # TODO(mdan): This needs to be explained more thoroughly. # Process any base classes: if the sueprclass if of a whitelisted type, an # absolute import line is generated. Otherwise, it is marked for conversion # (as a side effect of the call to namer.compiled_class_name() followed by - # conversion_map.update_name_map(namer)). + # program_ctx.update_name_map(namer)). output_nodes = [] renames = {} bases = [] @@ -247,7 +190,7 @@ def class_to_graph(c, conversion_map): alias = namer.compiled_class_name(base.__name__, base) bases.append(alias) renames[qual_names.QN(base.__name__)] = qual_names.QN(alias) - conversion_map.update_name_map(namer) + program_ctx.update_name_map(namer) # Generate the definition of the converted class. output_nodes.append( @@ -279,14 +222,14 @@ def _add_reserved_symbol(namespace, name, entity): ag_internal = None -def _add_self_references(namespace, api_module): +def _add_self_references(namespace, autograph_module): """Adds namespace references to the module that exposes the api itself.""" global ag_internal if ag_internal is None: # Craft a module that exposes parts of the external API as well as certain # internal modules. ag_internal = imp.new_module('autograph') - ag_internal.converted_call = api_module.converted_call + ag_internal.converted_call = autograph_module.converted_call ag_internal.utils = utils # TODO(mdan): Add safeguards against name clashes. # We don't want to create a submodule because we want the operators to be @@ -296,27 +239,24 @@ def _add_self_references(namespace, api_module): _add_reserved_symbol(namespace, 'ag__', ag_internal) -def function_to_graph(f, conversion_map, arg_values, arg_types, - owner_type=None): +def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None): """Specialization of `entity_to_graph` for callable functions.""" node, source = parser.parse_entity(f) node = node.body[0] namespace = inspect_utils.getnamespace(f) - _add_self_references(namespace, conversion_map.api_module) - namer = conversion_map.new_namer(namespace) + _add_self_references(namespace, program_ctx.autograph_module) + namer = program_ctx.new_namer(namespace) - ctx = context.EntityContext( - namer=namer, + entity_info = transformer.EntityInfo( source_code=source, source_file='', namespace=namespace, arg_values=arg_values, arg_types=arg_types, - owner_type=owner_type, - recursive=conversion_map.recursive, - type_annotation_func=type_hints.set_element_type) - node, deps = node_to_graph(node, ctx, conversion_map.nocompile_decorators) + owner_type=owner_type) + context = converter.EntityContext(namer, entity_info, program_ctx) + node = node_to_graph(node, context) # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type) @@ -326,29 +266,28 @@ def function_to_graph(f, conversion_map, arg_values, arg_types, raise NotImplementedError('Strange corner case. Send us offending code!') node.name = new_name - conversion_map.update_name_map(namer) + program_ctx.update_name_map(namer) # TODO(mdan): Use this at compilation. - conversion_map.additional_imports.update(deps) return node, new_name, namespace -def _static_analysis_pass(node, ctx): +def _apply_transformer(node, context, converter_module): + # TODO(mdan): Clear static analysis here. node = qual_names.resolve(node) - node = activity.resolve(node, ctx, None) - node = live_values.resolve(node, ctx, config.PYTHON_LITERALS) - node = type_info.resolve(node, ctx) + node = activity.resolve(node, context.info, None) + node = live_values.resolve(node, context.info, config.PYTHON_LITERALS) + node = type_info.resolve(node, context.info) + node = converter_module.transform(node, context) return node -def node_to_graph(node, ctx, nocompile_decorators): +def node_to_graph(node, context): """Convert Python code to equivalent TF graph mode code. Args: - node: A Python AST node representing the code to convert. - ctx: An EntityContext object. - nocompile_decorators: A tuple containing decorators to be stripped from - functions during conversion. + node: AST, the code to convert. + context: converter.EntityContext Returns: A tuple (node, deps): @@ -358,57 +297,26 @@ def node_to_graph(node, ctx, nocompile_decorators): """ # TODO(mdan): Verify arguments for correctness. - # TODO(mdan): Factor out common elements. - # These include: - # * code move between blocks - # * visiting blocks in transformers - - # Certain steps, especially canonicalization, insert new symbols into the - # tree, which must be accounted. Although less efficient, it is most robust - # to re-run the analysis. - - node = _static_analysis_pass(node, ctx) - - # TODO(mdan): Clean this up. - # Some intermediate analyses are not required, and some comments got orphaned. - - # TODO(mdan): We may assume all converters require analysis to be re-done. - + node = _apply_transformer(node, context, ifexp) # Past this point, line numbers are no longer accurate so we ignore the # source. # TODO(mdan): Is it feasible to reconstruct intermediate source code? - ctx.source_code = None - node = ifexp.transform(node, ctx) - node, deps = decorators.transform(node, nocompile_decorators) - node = break_statements.transform(node, ctx) - node = _static_analysis_pass(node, ctx) - - node = asserts.transform(node, ctx) - + context.info.source_code = None + node = _apply_transformer(node, context, decorators) + node = _apply_transformer(node, context, break_statements) + node = _apply_transformer(node, context, asserts) # Note: sequencing continue canonicalization before for loop one avoids # dealing with the extra loop increment operation that the for # canonicalization creates. - node = continue_statements.transform(node, ctx) - ctx.namespace['len'] = len - - node = _static_analysis_pass(node, ctx) - node = single_return.transform(node, ctx) - - node = _static_analysis_pass(node, ctx) - node = lists.transform(node, ctx) - node = _static_analysis_pass(node, ctx) - node = slices.transform(node, ctx) - node = builtin_functions.transform(node, ctx) - - node = _static_analysis_pass(node, ctx) - node = call_trees.transform(node, ctx, config.DEFAULT_UNCOMPILED_MODULES, - nocompile_decorators) - node = control_flow.transform(node, ctx) - - # control_flow may create new symbols and change scopes. - node = _static_analysis_pass(node, ctx) - node = logical_expressions.transform(node, ctx) - node = side_effect_guards.transform(node, ctx) - node = name_scopes.transform(node, ctx) - - return node, deps + node = _apply_transformer(node, context, continue_statements) + context.info.namespace['len'] = len + node = _apply_transformer(node, context, single_return) + node = _apply_transformer(node, context, lists) + node = _apply_transformer(node, context, slices) + node = _apply_transformer(node, context, builtin_functions) + node = _apply_transformer(node, context, call_trees) + node = _apply_transformer(node, context, control_flow) + node = _apply_transformer(node, context, logical_expressions) + node = _apply_transformer(node, context, side_effect_guards) + node = _apply_transformer(node, context, name_scopes) + return node diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py index bc61498b5422f5e130bbfeef935d0a796b4f5922..f5279298afdcd406a9a6762e58367cea8ca63141 100644 --- a/tensorflow/contrib/autograph/impl/conversion_test.py +++ b/tensorflow/contrib/autograph/impl/conversion_test.py @@ -21,6 +21,8 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.core import config +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.impl import api from tensorflow.contrib.autograph.impl import conversion from tensorflow.python.framework import constant_op @@ -30,8 +32,13 @@ from tensorflow.python.platform import test class ConversionTest(test.TestCase): - def _simple_conversion_map(self): - return conversion.ConversionMap(True, (), (), api) + def _simple_program_ctx(self): + return converter.ProgramContext( + recursive=True, + autograph_decorators=(), + partial_types=(), + autograph_module=api, + uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) def test_is_whitelisted_for_graph(self): @@ -44,16 +51,16 @@ class ConversionTest(test.TestCase): def test_entity_to_graph_unsupported_types(self): with self.assertRaises(ValueError): - conversion_map = self._simple_conversion_map() - conversion.entity_to_graph('dummy', conversion_map, None, None) + program_ctx = self._simple_program_ctx() + conversion.entity_to_graph('dummy', program_ctx, None, None) def test_entity_to_graph_callable(self): b = 2 def f(a): return a + b - conversion_map = self._simple_conversion_map() - ast, name, ns = conversion.entity_to_graph(f, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + ast, name, ns = conversion.entity_to_graph(f, program_ctx, None, None) self.assertTrue(isinstance(ast, gast.FunctionDef), ast) self.assertEqual('tf__f', name) self.assertTrue(ns['b'] is b) @@ -66,18 +73,17 @@ class ConversionTest(test.TestCase): def f(a): return g(a) - conversion_map = self._simple_conversion_map() - conversion.entity_to_graph(f, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + conversion.entity_to_graph(f, program_ctx, None, None) - self.assertTrue(f in conversion_map.dependency_cache) - self.assertTrue(g in conversion_map.dependency_cache) - self.assertEqual('tf__f', conversion_map.dependency_cache[f].name) + self.assertTrue(f in program_ctx.dependency_cache) + self.assertTrue(g in program_ctx.dependency_cache) + self.assertEqual('tf__f', program_ctx.dependency_cache[f].name) # need the extra .body[0] in order to step past the with tf.name_scope('f') # that is added automatically self.assertEqual( - 'tf__g', - conversion_map.dependency_cache[f].body[0].body[0].value.func.id) - self.assertEqual('tf__g', conversion_map.dependency_cache[g].name) + 'tf__g', program_ctx.dependency_cache[f].body[0].body[0].value.func.id) + self.assertEqual('tf__g', program_ctx.dependency_cache[g].name) def test_entity_to_graph_class_hierarchy(self): @@ -104,16 +110,15 @@ class ConversionTest(test.TestCase): def baz(self): return self.y - conversion_map = self._simple_conversion_map() - conversion.entity_to_graph(TestSubclass, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + conversion.entity_to_graph(TestSubclass, program_ctx, None, None) - self.assertTrue(TestBase in conversion_map.dependency_cache) - self.assertTrue(TestSubclass in conversion_map.dependency_cache) + self.assertTrue(TestBase in program_ctx.dependency_cache) + self.assertTrue(TestSubclass in program_ctx.dependency_cache) self.assertEqual('TfTestBase', - conversion_map.dependency_cache[TestBase].body[-1].name) - self.assertEqual( - 'TfTestSubclass', - conversion_map.dependency_cache[TestSubclass].body[-1].name) + program_ctx.dependency_cache[TestBase].body[-1].name) + self.assertEqual('TfTestSubclass', + program_ctx.dependency_cache[TestSubclass].body[-1].name) def test_entity_to_graph_class_hierarchy_whitelisted(self): @@ -126,24 +131,23 @@ class ConversionTest(test.TestCase): def call(self, x): return 3 * x - conversion_map = self._simple_conversion_map() - conversion.entity_to_graph(TestSubclass, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + conversion.entity_to_graph(TestSubclass, program_ctx, None, None) - self.assertTrue(TestSubclass in conversion_map.dependency_cache) - self.assertFalse(training.Model in conversion_map.dependency_cache) + self.assertTrue(TestSubclass in program_ctx.dependency_cache) + self.assertFalse(training.Model in program_ctx.dependency_cache) self.assertEqual( 'Model', - conversion_map.dependency_cache[TestSubclass].body[0].names[0].name) - self.assertEqual( - 'TfTestSubclass', - conversion_map.dependency_cache[TestSubclass].body[-1].name) + program_ctx.dependency_cache[TestSubclass].body[0].names[0].name) + self.assertEqual('TfTestSubclass', + program_ctx.dependency_cache[TestSubclass].body[-1].name) def test_entity_to_graph_lambda(self): f = lambda a: a with self.assertRaises(NotImplementedError): - conversion_map = self._simple_conversion_map() - conversion.entity_to_graph(f, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + conversion.entity_to_graph(f, program_ctx, None, None) def test_ag_module_cached(self): def callee(): @@ -152,11 +156,11 @@ class ConversionTest(test.TestCase): def caller(a): return a() - conversion_map = self._simple_conversion_map() - _, _, callee_ns = conversion.entity_to_graph( - callee, conversion_map, None, None) - _, _, caller_ns = conversion.entity_to_graph( - caller, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + _, _, callee_ns = conversion.entity_to_graph(callee, program_ctx, None, + None) + _, _, caller_ns = conversion.entity_to_graph(caller, program_ctx, None, + None) self.assertTrue(callee_ns['ag__'] is caller_ns['ag__']) diff --git a/tensorflow/contrib/autograph/impl/directives.py b/tensorflow/contrib/autograph/impl/directives.py deleted file mode 100644 index aabe5d99394a0cb921196d1c6a6b2a9496ea7545..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/impl/directives.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Directives are special no-op functions that serve as compilation markers. - -They provide static information like type hints, compilation and TensorFlow -overrides. - -These serve as annotations in the compiled code, allowing the user some control -over the compilation process. They have no functional role at runtime. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - - -UNSPECIFIED = object() - - -def set_element_type(entity, dtype, shape=UNSPECIFIED): - """Indicates that the entity is expected hold items of specified type/shape. - - The staged TensorFlow ops will reflect and assert this data type. Ignored - otherwise. - - Args: - entity: The entity to annotate. - dtype: TensorFlow dtype value to assert for entity. - shape: Optional shape to assert for entity. - """ - del entity - del dtype - del shape - - -def set_loop_options( - parallel_iterations=UNSPECIFIED, - back_prop=UNSPECIFIED, - swap_memory=UNSPECIFIED, - maximum_iterations=UNSPECIFIED): - """Specifies additional arguments to be passed to the enclosing while_loop. - - The parameters apply to and only to the immediately enclosing loop. It only - has effect if the loop is staged as a TF while_loop; otherwise the parameters - have no effect. - - Args: - parallel_iterations: See tf.while_loop. - back_prop: See tf.while_loop. - swap_memory: See tf.while_loop. - maximum_iterations: See tf.while_loop. - """ - del parallel_iterations - del back_prop - del swap_memory - del maximum_iterations diff --git a/tensorflow/contrib/autograph/impl/special_functions.py b/tensorflow/contrib/autograph/impl/special_functions.py deleted file mode 100644 index b7a8177c44c88217560fb7f72c77d3ac1aa0c9ec..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/impl/special_functions.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Special functions that only make sense for AutoGraph. - -These functions are meant to ensure feature parity between Python and AutoGraph, -so that the exact same code works in both modes. In general, AutoGraph will -replace these calls. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.autograph.operators import data_structures - - -def stack(list_or_tensor, element_dtype=None): - """Stacks the input, if it admits the notion of stacking. No-op otherwise. - - For example, a list of tensors can be stacked into a larger tensor. This - function is similar to tf.stack, but it accepts non-lists and lists of - non-tensors as arguments. In the latter case, the function does nothing. - - Args: - list_or_tensor: Any entity. - element_dtype: Optional dtype for the elements in the list. Required if the - input is stackable, and the list is untyped. - - Returns: - If the input is stackable, a new object representing the stacked inputs. - Otherwise it returns list_or_tensor unchanged. - """ - return data_structures.list_stack( - list_or_tensor, - data_structures.ListStackOpts( - element_dtype=element_dtype, original_call=lambda x: x)) diff --git a/tensorflow/contrib/autograph/impl/special_functions_test.py b/tensorflow/contrib/autograph/impl/special_functions_test.py deleted file mode 100644 index 9b52d2a59b5a3e3c92f11343197379c773ecc828..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/impl/special_functions_test.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for special_functions module.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.autograph.impl import special_functions -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import list_ops -from tensorflow.python.platform import test - - -class SpecialFunctionsTest(test.TestCase): - - def test_basic(self): - self.assertEqual(special_functions.stack(1), 1) - self.assertListEqual(special_functions.stack([1, 2, 3]), [1, 2, 3]) - # TODO(mdan): This should probably forward to tf.stack. - self.assertTrue( - isinstance( - special_functions.stack( - [constant_op.constant(1), - constant_op.constant(2)]), list)) - - t = constant_op.constant([1.0, 2.0]) - l = list_ops.tensor_list_from_tensor( - t, element_shape=constant_op.constant([], dtype=dtypes.int32)) - self.assertTrue( - tensor_util.is_tensor( - special_functions.stack(l, element_dtype=dtypes.float32))) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD index 0c6ab65505ee03e19588adae73d3134399a34b65..332d5dab19e7ade1531b564fbdef2fa0dc2d09d5 100644 --- a/tensorflow/contrib/autograph/operators/BUILD +++ b/tensorflow/contrib/autograph/operators/BUILD @@ -28,7 +28,15 @@ py_library( visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/contrib/autograph/utils", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:list_ops", "//tensorflow/python:tensor_array_ops", + "//tensorflow/python:tensor_util", + "//tensorflow/python:variables", "//tensorflow/python/data/ops:dataset_ops", ], ) diff --git a/tensorflow/contrib/autograph/pyct/BUILD b/tensorflow/contrib/autograph/pyct/BUILD index 989b821e53a5cefbe39095e669f9a9e0bec65b8a..a49a4ed05ca99a5c9784cfc132784890e63a94de 100644 --- a/tensorflow/contrib/autograph/pyct/BUILD +++ b/tensorflow/contrib/autograph/pyct/BUILD @@ -22,8 +22,8 @@ py_library( "__init__.py", "anno.py", "ast_util.py", + "cfg.py", "compiler.py", - "context.py", "inspect_utils.py", "parser.py", "pretty_printer.py", @@ -38,6 +38,8 @@ py_library( "@gast_archive//:gast", "@six_archive//:six", "@termcolor_archive//:termcolor", + # TODO(mdan): Remove this dependency. + "//tensorflow/python:util", ], ) @@ -62,6 +64,17 @@ py_test( ], ) +py_test( + name = "cfg_test", + srcs = ["cfg_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) + py_test( name = "compiler_test", srcs = ["compiler_test.py"], diff --git a/tensorflow/contrib/autograph/pyct/cfg.py b/tensorflow/contrib/autograph/pyct/cfg.py new file mode 100644 index 0000000000000000000000000000000000000000..666328781f683c9457f6892c0a26088c33ba94a7 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/cfg.py @@ -0,0 +1,733 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Control flow graph (CFG) structure for Python AST representation. + +The CFG is a digraph with edges representing valid control flow. Each +node is associated with exactly one AST node, but not all AST nodes may have +a corresponding CFG counterpart. + +Once built, the CFG itself is immutable, but the values it holds need not be; +they are usually annotated with information extracted by walking the graph. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +from enum import Enum + +# pylint:disable=g-bad-import-order +import gast +# pylint:enable=g-bad-import-order + +from tensorflow.contrib.autograph.pyct import compiler + + +class Node(object): + """A node in the CFG. + + Although new instances of this class are mutable, the objects that a user + finds in the CFG are typically not. + + The nodes represent edges in the CFG graph, and maintain pointers to allow + efficient walking in both forward and reverse order. The following property + holds for all nodes: "child in node.next" iff "node in child.prev". + + Attributes: + next: FrozenSet[Node, ...], the nodes that follow this node, in control + flow order + prev: FrozenSet[Node, ...], the nodes that precede this node, in reverse + control flow order + ast_node: ast.AST, the AST node corresponding to this CFG node + """ + + def __init__(self, next_, prev, ast_node): + self.next = next_ + self.prev = prev + self.ast_node = ast_node + + def freeze(self): + self.next = frozenset(self.next) + self.prev = frozenset(self.prev) + + def __repr__(self): + return compiler.ast_to_source(self.ast_node).strip() + + +class Graph( + collections.namedtuple('Graph', ['entry', 'exit', 'error', 'index'])): + """A Control Flow Graph. + + The CFG maintains an index to allow looking up a CFG node by the AST node to + which it is associated. The index can also be enumerated in top-down, depth + first order. + + Walking the graph in forward or reverse order is supported by double + parent-child links. + + Note: the error nodes are not wired to their corresponding finally guards, + because these are shared, and wiring them would create a reverse path from + normal control flow into the error nodes, which we want to avoid. + + Attributes: + entry: Node, the entry node + exit: FrozenSet[Node, ...], the exit nodes + error: FrozenSet[Node, ...], nodes that exit due to an explicitly raised + error (errors propagated from function calls are not accounted) + index: Dict[ast.Node, Node], mapping AST nodes to the respective CFG + node + """ + + def __repr__(self): + result = 'digraph CFG {\n' + for node in self.index.values(): + result += ' %s [label="%s"];\n' % (id(node), node) + for node in self.index.values(): + if node.next: + result += ' %s -> {%s};\n' % (id(node), ', '.join( + repr(id(n)) for n in node.next)) + result += '}' + return result + + +class _WalkMode(Enum): + FORWARD = 1 + REVERSE = 2 + + +class GraphVisitor(object): + """Base class for a CFG visitors. + + This implementation is not thread safe. + + The visitor has some facilities to simplify dataflow analyses. In particular, + it allows revisiting the nodes at the decision of the subclass. This can be + used to visit the graph until the state reaches a fixed point. + + For more details on dataflow analysis, see + https://www.seas.harvard.edu/courses/cs252/2011sp/slides/Lec02-Dataflow.pdf + + Note: the literature generally suggests visiting successor nodes only when the + state of the current node changed, regardless of whether that successor has + ever been visited. This implementation visits every successor at least once. + + Attributes: + graph: Graph + in_: Dict[Node, Any], stores node-keyed state during a visit + out: Dict[Node, Any], stores node-keyed state during a visit + """ + + def reset(self): + self.in_ = { + node: self.init_state(node) for node in self.graph.index.values() + } + self.out = { + node: self.init_state(node) for node in self.graph.index.values() + } + + def init_state(self, node): + """State initialization function. Optional to overload. + + An in/out state slot will be created for each node in the graph. Subclasses + may overload this to control what that is initialized to. + + Args: + node: Node + """ + del node + return None + + def visit_node(self, node): + """Visitor function. + + Args: + node: Node + Returns: + bool, whether the node should be revisited; subclasses can visit every + reachable node exactly once by always returning False + """ + raise NotImplementedError('Subclasses must implement this.') + + def _visit_internal(self, mode): + """Visits the CFG, depth-first.""" + assert mode in (_WalkMode.FORWARD, _WalkMode.REVERSE) + if mode == _WalkMode.FORWARD: + open_ = [self.graph.entry] + elif mode == _WalkMode.REVERSE: + open_ = list(self.graph.exit) + closed = set() + self.reset() + + while open_: + node = open_.pop(0) + closed.add(node) + + should_revisit = self.visit_node(node) + + if mode == _WalkMode.FORWARD: + children = node.next + elif mode == _WalkMode.REVERSE: + children = node.prev + + for next_ in children: + if should_revisit or next_ not in closed: + open_.append(next_) + + def visit_forward(self, graph): + self.graph = graph + self._visit_internal(_WalkMode.FORWARD) + + def visit_reverse(self, graph): + self.graph = graph + self._visit_internal(_WalkMode.REVERSE) + + +class GraphBuilder(object): + """Builder that constructs a CFG from a given AST. + + This GraphBuilder facilitates constructing the DAG that forms the CFG when + nodes + are supplied in lexical order (i.e., top-down, depth first). Under these + conditions, it supports building patterns found in typical structured + programs. + + This builder ignores the flow generated by exceptions, which are assumed to + always be catastrophic and present purely for diagnostic purposes (e.g. to + print debug information). Statements like raise and try/catch sections are + allowed and will generate control flow edges, but ordinaty statements are + assumed not to raise exceptions. + + Finally sections are also correctly interleaved between break/continue/return + nodes and their subsequent statements. + + Important concepts: + * nodes - nodes refer refer to CFG nodes; AST nodes are qualified explicitly + * leaf set - since the graph is constructed gradually, a leaf set maintains + the CFG nodes that will precede the node that the builder expects to + receive next; when an ordinary node is added, it is connected to the + existing leaves and it in turn becomes the new leaf + * jump nodes - nodes that should generate edges other than what + ordinary nodes would; these correspond to break, continue and return + statements + * sections - logical delimiters for subgraphs that require special + edges; there are various types of nodes, each admitting various + types of jump nodes; sections are identified by their corresponding AST + node + """ + + # TODO(mdan): Perhaps detail this in a markdown doc. + # TODO(mdan): Add exception support. + + def __init__(self, parent_ast_node): + self.reset() + self.parent = parent_ast_node + + def reset(self): + """Resets the state of this factory.""" + self.head = None + self.errors = set() + self.node_index = collections.OrderedDict() + + # TODO(mdan): Too many primitives. Use classes. + self.leaves = set() + + self.finally_sections = {} + self.finally_section_subgraphs = {} # Values are [begin_node, exit_nodes] + # Whether the guard section can be reached from the statement that precedes + # it. + self.finally_section_has_direct_flow = {} + # Finally sections that await their first node. + self.pending_finally_sections = set() + + # Exit jumps keyed by the section they affect. + self.exits = {} + + # The entry of loop sections, keyed by the section. + self.section_entry = {} + # Continue jumps keyed by the section they affect. + self.continues = {} + + # The entry of conditional sections, keyed by the section. + self.cond_entry = {} + # Lists of leaf nodes corresponding to each branch in the section. + self.cond_leaves = {} + + def _connect_nodes(self, first, second): + """Connects nodes to signify that control flows from first to second. + + Args: + first: Union[Set[Node, ...], Node] + second: Node + """ + if isinstance(first, Node): + first.next.add(second) + second.prev.add(first) + else: + for node in first: + self._connect_nodes(node, second) + + def _add_new_node(self, ast_node): + """Grows the graph by adding a CFG node following the current leaves.""" + if ast_node is self.node_index: + raise ValueError('%s added twice' % ast_node) + node = Node(next_=set(), prev=set(), ast_node=ast_node) + self.node_index[ast_node] = node + + if self.head is None: + self.head = node + + for leaf in self.leaves: + self._connect_nodes(leaf, node) + + # If any finally section awaits its first node, populate it. + for section_id in self.pending_finally_sections: + self.finally_section_subgraphs[section_id][0] = node + self.pending_finally_sections = set() + + return node + + def add_ordinary_node(self, ast_node): + """Grows the graph by adding an ordinary CFG node. + + Ordinary nodes are followed by the next node, in lexical order, that is, + they become the new leaf set. + + Args: + ast_node: ast.AST + Returns: + Node + """ + node = self._add_new_node(ast_node) + self.leaves = set((node,)) + return node + + def _add_jump_node(self, ast_node, guards): + """Grows the graph by adding a jump node. + + Jump nodes are added to the current leaf set, and the leaf set becomes + empty. If the jump node is the last in a cond section, then it may be added + back to the leaf set by a separate mechanism. + + Args: + ast_node: ast.AST + guards: Tuple[ast.AST, ...], the finally sections active for this node + Returns: + Node + """ + node = self._add_new_node(ast_node) + self.leaves = set() + # The guards themselves may not yet be complete, and will be wired later. + self.finally_sections[node] = guards + return node + + def _connect_jump_to_finally_sections(self, node): + """Connects a jump node to the finally sections protecting it.""" + cursor = set((node,)) + for guard_section_id in self.finally_sections[node]: + guard_begin, guard_ends = self.finally_section_subgraphs[guard_section_id] + self._connect_nodes(cursor, guard_begin) + cursor = guard_ends + del self.finally_sections[node] + # TODO(mdan): Should garbage-collect finally_section_subgraphs. + return cursor + + def add_exit_node(self, ast_node, section_id, guards): + """Grows the graph by adding an exit node. + + This node becomes an exit for the current section. + + Args: + ast_node: ast.AST + section_id: Hashable, the node for which ast_node should be considered + to be an exit node + guards: Tuple[ast.AST, ...], the finally sections that guard ast_node + """ + node = self._add_jump_node(ast_node, guards) + self.exits[section_id].add(node) + + def add_continue_node(self, ast_node, section_id, guards): + """Grows the graph by adding a reentry node. + + This node causes control flow to go back to the loop section's entry. + + Args: + ast_node: ast.AST + section_id: Hashable, the node for which ast_node should be considered + to be an exit node + guards: Tuple[ast.AST, ...], the finally sections that guard ast_node + """ + node = self._add_jump_node(ast_node, guards) + self.continues[section_id].add(node) + + def add_error_node(self, ast_node, guards): + """Grows the graph by adding an error node. + + This node becomes an exit for the entire graph. + + Args: + ast_node: ast.AST + guards: Tuple[ast.AST, ...], the finally sections that guard ast_node + """ + node = self._add_jump_node(ast_node, guards) + self.errors.add(node) + self.leaves = set() + + def enter_section(self, section_id): + """Enters a regular section. + + Regular sections admit exit jumps, which end the section. + + Args: + section_id: Hashable, the same node that will be used in calls to the + ast_node arg passed to add_exit_node + """ + assert section_id not in self.exits + self.exits[section_id] = set() + + def exit_section(self, section_id): + """Exits a regular section.""" + + # Exits are jump nodes, which may be protected. + for exit_ in self.exits[section_id]: + self.leaves |= self._connect_jump_to_finally_sections(exit_) + + del self.exits[section_id] + + def enter_loop_section(self, section_id, entry_node): + """Enters a loop section. + + Loop sections define an entry node. The end of the section always flows back + to the entry node. These admit continue jump nodes which also flow to the + entry node. + + Args: + section_id: Hashable, the same node that will be used in calls to the + ast_node arg passed to add_continue_node + entry_node: ast.AST, the entry node into the loop (e.g. the test node + for while loops) + """ + assert section_id not in self.section_entry + assert section_id not in self.continues + self.continues[section_id] = set() + node = self.add_ordinary_node(entry_node) + self.section_entry[section_id] = node + + def exit_loop_section(self, section_id): + """Exits a loop section.""" + self._connect_nodes(self.leaves, self.section_entry[section_id]) + + # continues are jump nodes, which may be protected. + for reentry in self.continues[section_id]: + guard_ends = self._connect_jump_to_finally_sections(reentry) + self._connect_nodes(guard_ends, self.section_entry[section_id]) + + # Loop nodes always loop back. + self.leaves = set((self.section_entry[section_id],)) + + del self.continues[section_id] + del self.section_entry[section_id] + + def enter_cond_section(self, section_id): + """Enters a conditional section. + + Conditional sections define an entry node, and one or more branches. + + Args: + section_id: Hashable, the same node that will be used in calls to the + section_id arg passed to new_cond_branch + """ + + assert section_id not in self.cond_entry + assert section_id not in self.cond_leaves + self.cond_leaves[section_id] = [] + + def new_cond_branch(self, section_id): + """Begins a new branch in a cond section.""" + assert section_id in self.cond_leaves + + if section_id in self.cond_entry: + # Subsequent splits move back to the split point, and memorize the + # current leaves. + self.cond_leaves[section_id].append(self.leaves) + self.leaves = self.cond_entry[section_id] + else: + # If this is the first time we split a section, just remember the split + # point. + self.cond_entry[section_id] = self.leaves + + def exit_cond_section(self, section_id): + """Exits a conditional section.""" + for split in self.cond_leaves[section_id]: + self.leaves |= split + del self.cond_entry[section_id] + del self.cond_leaves[section_id] + + def enter_finally_section(self, section_id): + """Enters a finally section.""" + # TODO(mdan): This, not the caller, should track the active sections. + self.finally_section_subgraphs[section_id] = [None, None] + if self.leaves: + self.finally_section_has_direct_flow[section_id] = True + else: + self.finally_section_has_direct_flow[section_id] = False + self.pending_finally_sections.add(section_id) + + def exit_finally_section(self, section_id): + """Exits a finally section.""" + assert section_id not in self.pending_finally_sections, 'Empty finally?' + self.finally_section_subgraphs[section_id][1] = self.leaves + # If the guard can only be reached by a jump, then it will not flow + # into the statement that follows it. + if not self.finally_section_has_direct_flow[section_id]: + self.leaves = set() + del self.finally_section_has_direct_flow[section_id] + + def build(self): + """Returns the CFG accumulated so far and resets the builder. + + Returns: + Graph + """ + # Freeze the nodes. + for node in self.node_index.values(): + node.freeze() + + result = Graph( + entry=self.head, + exit=self.leaves, + error=self.errors, + index=self.node_index) + + # Reset the state. + self.reset() + + return result + + +class AstToCfg(gast.NodeVisitor): + """Converts an AST to CFGs. + + A separate CFG will be constructed for each function. + """ + + # TODO(mdan): Figure out how to deal with closures. + + def __init__(self): + super(AstToCfg, self).__init__() + + self.builder_stack = [] + self.builder = None + self.cfgs = {} + + self.lexical_scopes = [] + + def _enter_lexical_scope(self, node): + self.lexical_scopes.append(node) + + def _exit_lexical_scope(self, node): + leaving_node = self.lexical_scopes.pop() + assert node == leaving_node + + def _get_enclosing_scopes(self, include, stop_at): + included = [] + for node in reversed(self.lexical_scopes): + if isinstance(node, include): + included.append(node) + if isinstance(node, stop_at): + return node, included + return None, included + + def _process_basic_statement(self, node): + self.generic_visit(node) + self.builder.add_ordinary_node(node) + + def _process_exit_statement(self, node, *exits_nodes_of_type): + # Note: this is safe because we process functions separately. + try_node, guards = self._get_enclosing_scopes( + include=(gast.Try,), + stop_at=tuple(exits_nodes_of_type), + ) + if try_node is None: + raise ValueError( + '%s that is not enclosed by any of %s' % (node, exits_nodes_of_type)) + self.builder.add_exit_node(node, try_node, guards) + + def _process_continue_statement(self, node, *loops_to_nodes_of_type): + # Note: this is safe because we process functions separately. + try_node, guards = self._get_enclosing_scopes( + include=(gast.Try,), + stop_at=tuple(loops_to_nodes_of_type), + ) + if try_node is None: + raise ValueError('%s that is not enclosed by any of %s' % + (node, loops_to_nodes_of_type)) + self.builder.add_continue_node(node, try_node, guards) + + def visit_FunctionDef(self, node): + self.builder_stack.append(self.builder) + self.builder = GraphBuilder(node) + + self._enter_lexical_scope(node) + self.builder.enter_section(node) + + self._process_basic_statement(node.args) + for stmt in node.body: + self.visit(stmt) + + self.builder.exit_section(node) + self._exit_lexical_scope(node) + + self.cfgs[node] = self.builder.build() + self.builder = self.builder_stack.pop() + + def visit_Lambda(self, node): + # TODO(mdan): Treat like FunctionDef? That would be a separate CFG. + raise NotImplementedError() + + def visit_Return(self, node): + self._process_exit_statement(node, gast.FunctionDef) + + def visit_Expr(self, node): + self._process_basic_statement(node) + + def visit_Assign(self, node): + self._process_basic_statement(node) + + def visit_AnnAssign(self, node): + self._process_basic_statement(node) + + def visit_AugAssign(self, node): + self._process_basic_statement(node) + + def visit_Print(self, node): + self._process_basic_statement(node) + + def visit_Raise(self, node): + try_node, guards = self._get_enclosing_scopes( + include=(gast.Try,), + stop_at=(gast.FunctionDef,), + ) + if try_node is None: + raise ValueError('%s that is not enclosed by any FunctionDef' % node) + self.builder.add_error_node(node, try_node, guards) + + def visit_Assert(self, node): + # Ignoring the effect of exceptions. + self._process_basic_statement(node) + + def visit_Delete(self, node): + self._process_basic_statement(node) + + def visit_If(self, node): + # No need to track ifs as lexical scopes, for now. + # Lexical scopes are generally tracked in order to be able to resolve the + # targets of jump statements like break/continue/etc. Since there is no + # statement that can interrupt a conditional, we don't need to track their + # lexical scope. That may change in the future. + + self.builder.enter_cond_section(node) + self._process_basic_statement(node.test) + + self.builder.new_cond_branch(node) + for stmt in node.body: + self.visit(stmt) + + self.builder.new_cond_branch(node) + for stmt in node.orelse: + self.visit(stmt) + + self.builder.exit_cond_section(node) + + def visit_While(self, node): + self._enter_lexical_scope(node) + + self.builder.enter_section(node) + + self.builder.enter_loop_section(node, node.test) + for stmt in node.body: + self.visit(stmt) + self.builder.exit_loop_section(node) + + # Note: although the orelse is technically part of the loop node, + # the statements inside it don't affect the loop itself. For example, a + # break in the loop's orelse will not affect the loop itself. + self._exit_lexical_scope(node) + + for stmt in node.orelse: + self.visit(stmt) + + self.builder.exit_section(node) + + def visit_For(self, node): + self._enter_lexical_scope(node) + + self.builder.enter_section(node) + + # TODO(mdan): Strictly speaking, this should be node.target + node.iter. + # A blind dataflow analysis would have to process both node.target and + # node.iter to properly process read and write access. + self.builder.enter_loop_section(node, node.iter) + for stmt in node.body: + self.visit(stmt) + self.builder.exit_loop_section(node) + + # Note: although the orelse is technically part of the loop node, + # they don't count as loop bodies. For example, a break in the loop's + # orelse will affect the parent loop, not the current one. + self._exit_lexical_scope(node) + + for stmt in node.orelse: + self.visit(stmt) + + self.builder.exit_section(node) + + def visit_Break(self, node): + self._process_exit_statement(node, gast.While, gast.For) + + def visit_Continue(self, node): + self._process_continue_statement(node, gast.While, gast.For) + + def visit_Try(self, node): + self._enter_lexical_scope(node) + + for stmt in node.body: + self.visit(stmt) + # Unlike loops, the orelse is a simple continuation of the body. + for stmt in node.orelse: + self.visit(stmt) + + if node.handlers: + # TODO(mdan): Should we still support bare try/except? Might be confusing. + raise NotImplementedError('exceptions are not yet supported') + + self._exit_lexical_scope(node) + + self.builder.enter_finally_section(node) + for stmt in node.finalbody: + self.visit(stmt) + self.builder.exit_finally_section(node) + + def visit_With(self, node): + # TODO(mdan): Mark the context manager's exit call as exit guard. + self._process_basic_statement(node.items) + for stmt in node.body: + self.visit(stmt) + + +def build(node): + builder = AstToCfg() + builder.visit(node) + return builder.cfgs diff --git a/tensorflow/contrib/autograph/pyct/cfg_test.py b/tensorflow/contrib/autograph/pyct/cfg_test.py new file mode 100644 index 0000000000000000000000000000000000000000..00afadd5212a3aba8f25cd9a6f111d292635bbce --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/cfg_test.py @@ -0,0 +1,790 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for cfg module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import cfg +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.python.platform import test + + +class CountingVisitor(cfg.GraphVisitor): + + def __init__(self): + self.counts = {} + + def visit_node(self, node): + self.counts[node.ast_node] = self.counts.get(node.ast_node, 0) + 1 + return False # visit only once + + +class GraphVisitorTest(test.TestCase): + + def _build_cfg(self, fn): + node, _ = parser.parse_entity(fn) + cfgs = cfg.build(node) + return cfgs, node + + def test_basic_coverage_forward(self): + + def test_fn(a): + while a > 0: + a = 1 + break + return a # pylint:disable=unreachable + a = 2 + + graphs, node = self._build_cfg(test_fn) + graph, = graphs.values() + visitor = CountingVisitor() + visitor.visit_forward(graph) + fn_node = node.body[0] + + self.assertEqual(visitor.counts[fn_node.args], 1) + self.assertEqual(visitor.counts[fn_node.body[0].test], 1) + self.assertEqual(visitor.counts[fn_node.body[0].body[0]], 1) + self.assertEqual(visitor.counts[fn_node.body[0].body[1]], 1) + # The return node should be unreachable in forward direction. + self.assertTrue(fn_node.body[0].body[2] not in visitor.counts) + self.assertEqual(visitor.counts[fn_node.body[1]], 1) + + def test_basic_coverage_reverse(self): + + def test_fn(a): + while a > 0: + a = 1 + break + return a # pylint:disable=unreachable + a = 2 + + graphs, node = self._build_cfg(test_fn) + graph, = graphs.values() + visitor = CountingVisitor() + visitor.visit_reverse(graph) + fn_node = node.body[0] + + self.assertEqual(visitor.counts[fn_node.args], 1) + self.assertEqual(visitor.counts[fn_node.body[0].test], 1) + self.assertEqual(visitor.counts[fn_node.body[0].body[0]], 1) + self.assertEqual(visitor.counts[fn_node.body[0].body[1]], 1) + self.assertTrue(visitor.counts[fn_node.body[0].body[2]], 1) + self.assertEqual(visitor.counts[fn_node.body[1]], 1) + + +class AstToCfgTest(test.TestCase): + + def _build_cfg(self, fn): + node, _ = parser.parse_entity(fn) + cfgs = cfg.build(node) + return cfgs + + def _repr_set(self, node_set): + return set(repr(n) for n in node_set) + + def _as_set(self, elements): + if elements is None: + return frozenset() + elif isinstance(elements, str): + return frozenset((elements,)) + else: + return frozenset(elements) + + def assertGraphMatches(self, graph, edges): + """Tests whether the CFG contains the specified edges.""" + for prev, node_repr, next_ in edges: + matched = False + for cfg_node in graph.index.values(): + if repr(cfg_node) == node_repr: + if (self._as_set(prev) == set(map(repr, cfg_node.prev)) and + self._as_set(next_) == set(map(repr, cfg_node.next))): + matched = True + break + if not matched: + self.fail( + 'match failed for node "%s" in graph:\n%s' % (node_repr, graph)) + + def test_straightline(self): + + def test_fn(a): + a += 1 + a = 2 + a = 3 + return + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', 'a += 1'), + ('a += 1', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', 'return'), + ('a = 3', 'return', None), + ), + ) + + def test_straightline_no_return(self): + + def test_fn(a, b): + a = b + 1 + a += max(a) + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a, b', 'a = b + 1'), + ('a = b + 1', 'a += max(a)', None), + ), + ) + + def test_unreachable_code(self): + + def test_fn(a): + return + a += 1 # pylint:disable=unreachable + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', 'return'), + ('a', 'return', None), + (None, 'a += 1', None), + ), + ) + + def test_branch_straightline(self): + + def test_fn(a): + if a > 0: + a = 1 + else: + a += -1 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', '(a > 0)'), + ('(a > 0)', 'a = 1', None), + ('(a > 0)', 'a += -1', None), + ), + ) + + def test_branch_nested(self): + + def test_fn(a): + if a > 0: + if a > 1: + a = 1 + else: + a = 2 + else: + if a > 2: + a = 3 + else: + a = 4 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', '(a > 0)'), + ('a', '(a > 0)', ('(a > 1)', '(a > 2)')), + ('(a > 0)', '(a > 1)', ('a = 1', 'a = 2')), + ('(a > 1)', 'a = 1', None), + ('(a > 1)', 'a = 2', None), + ('(a > 0)', '(a > 2)', ('a = 3', 'a = 4')), + ('(a > 2)', 'a = 3', None), + ('(a > 2)', 'a = 4', None), + ), + ) + + def test_branch_straightline_semi(self): + + def test_fn(a): + if a > 0: + a = 1 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', '(a > 0)'), + ('a', '(a > 0)', 'a = 1'), + ('(a > 0)', 'a = 1', None), + ), + ) + + def test_branch_return(self): + + def test_fn(a): + if a > 0: + return + else: + a = 1 + a = 2 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', '(a > 0)', ('return', 'a = 1')), + ('(a > 0)', 'a = 1', 'a = 2'), + ('(a > 0)', 'return', None), + ('a = 1', 'a = 2', None), + ), + ) + + def test_branch_return_minimal(self): + + def test_fn(a): + if a > 0: + return + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', '(a > 0)', 'return'), + ('(a > 0)', 'return', None), + ), + ) + + def test_while_straightline(self): + + def test_fn(a): + while a > 0: + a = 1 + a = 2 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', ('a = 1', 'a = 2')), + ('(a > 0)', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', None), + ), + ) + + def test_while_else_straightline(self): + + def test_fn(a): + while a > 0: + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', ('a = 1', 'a = 2')), + ('(a > 0)', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_while_else_continue(self): + + def test_fn(a): + while a > 0: + if a > 1: + continue + else: + a = 0 + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'continue', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')), + ('(a > 0)', '(a > 1)', ('continue', 'a = 0')), + ('(a > 1)', 'continue', '(a > 0)'), + ('a = 0', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_while_else_break(self): + + def test_fn(a): + while a > 0: + if a > 1: + break + a = 1 + else: + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')), + ('(a > 0)', '(a > 1)', ('break', 'a = 1')), + ('(a > 1)', 'break', 'a = 3'), + ('(a > 1)', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', 'a = 3'), + (('break', 'a = 2'), 'a = 3', None), + ), + ) + + def test_while_else_return(self): + + def test_fn(a): + while a > 0: + if a > 1: + return + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')), + ('(a > 0)', '(a > 1)', ('return', 'a = 1')), + ('(a > 1)', 'return', None), + ('(a > 1)', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_while_nested_straightline(self): + + def test_fn(a): + while a > 0: + while a > 1: + a = 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')), + (('(a > 0)', 'a = 1'), '(a > 1)', ('a = 1', 'a = 2')), + ('(a > 1)', 'a = 1', '(a > 1)'), + ('(a > 1)', 'a = 2', '(a > 0)'), + ('(a > 0)', 'a = 3', None), + ), + ) + + def test_while_nested_continue(self): + + def test_fn(a): + while a > 0: + while a > 1: + if a > 3: + continue + a = 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')), + (('(a > 0)', 'continue', 'a = 1'), '(a > 1)', ('(a > 3)', 'a = 2')), + ('(a > 1)', '(a > 3)', ('continue', 'a = 1')), + ('(a > 3)', 'continue', '(a > 1)'), + ('(a > 3)', 'a = 1', '(a > 1)'), + ('(a > 1)', 'a = 2', '(a > 0)'), + ('(a > 0)', 'a = 3', None), + ), + ) + + def test_while_nested_break(self): + + def test_fn(a): + while a > 0: + while a > 1: + if a > 2: + break + a = 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')), + (('(a > 0)', 'a = 1'), '(a > 1)', ('(a > 2)', 'a = 2')), + ('(a > 1)', '(a > 2)', ('break', 'a = 1')), + ('(a > 2)', 'break', 'a = 2'), + ('(a > 2)', 'a = 1', '(a > 1)'), + (('(a > 1)', 'break'), 'a = 2', '(a > 0)'), + ('(a > 0)', 'a = 3', None), + ), + ) + + def test_for_straightline(self): + + def test_fn(a): + for a in range(0, a): + a = 1 + a = 2 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), 'range(0, a)', ('a = 1', 'a = 2')), + ('range(0, a)', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', None), + ), + ) + + def test_for_else_straightline(self): + + def test_fn(a): + for a in range(0, a): + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), 'range(0, a)', ('a = 1', 'a = 2')), + ('range(0, a)', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_for_else_continue(self): + + def test_fn(a): + for a in range(0, a): + if a > 1: + continue + else: + a = 0 + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'continue', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')), + ('range(0, a)', '(a > 1)', ('continue', 'a = 0')), + ('(a > 1)', 'continue', 'range(0, a)'), + ('(a > 1)', 'a = 0', 'a = 1'), + ('a = 0', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_for_else_break(self): + + def test_fn(a): + for a in range(0, a): + if a > 1: + break + a = 1 + else: + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')), + ('range(0, a)', '(a > 1)', ('break', 'a = 1')), + ('(a > 1)', 'break', 'a = 3'), + ('(a > 1)', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', 'a = 3'), + (('break', 'a = 2'), 'a = 3', None), + ), + ) + + def test_for_else_return(self): + + def test_fn(a): + for a in range(0, a): + if a > 1: + return + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')), + ('range(0, a)', '(a > 1)', ('return', 'a = 1')), + ('(a > 1)', 'return', None), + ('(a > 1)', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_for_nested_straightline(self): + + def test_fn(a): + for a in range(0, a): + for b in range(1, a): + b += 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')), + (('range(0, a)', 'b += 1'), 'range(1, a)', ('b += 1', 'a = 2')), + ('range(1, a)', 'b += 1', 'range(1, a)'), + ('range(1, a)', 'a = 2', 'range(0, a)'), + ('range(0, a)', 'a = 3', None), + ), + ) + + def test_for_nested_continue(self): + + def test_fn(a): + for a in range(0, a): + for b in range(1, a): + if a > 3: + continue + b += 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')), + (('range(0, a)', 'continue', 'b += 1'), 'range(1, a)', + ('(a > 3)', 'a = 2')), + ('range(1, a)', '(a > 3)', ('continue', 'b += 1')), + ('(a > 3)', 'continue', 'range(1, a)'), + ('(a > 3)', 'b += 1', 'range(1, a)'), + ('range(1, a)', 'a = 2', 'range(0, a)'), + ('range(0, a)', 'a = 3', None), + ), + ) + + def test_for_nested_break(self): + + def test_fn(a): + for a in range(0, a): + for b in range(1, a): + if a > 2: + break + b += 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')), + (('range(0, a)', 'b += 1'), 'range(1, a)', ('(a > 2)', 'a = 2')), + ('range(1, a)', '(a > 2)', ('break', 'b += 1')), + ('(a > 2)', 'break', 'a = 2'), + ('(a > 2)', 'b += 1', 'range(1, a)'), + (('range(1, a)', 'break'), 'a = 2', 'range(0, a)'), + ('range(0, a)', 'a = 3', None), + ), + ) + + def test_complex(self): + + def test_fn(a): + b = 0 + while a > 0: + for b in range(0, a): + if a > 2: + break + if a > 3: + if a > 4: + continue + else: + max(a) + break + b += 1 + else: # for b in range(0, a): + return a + a = 2 + for a in range(1, a): + return b + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('b = 0', 'a = 2'), '(a > 0)', ('range(0, a)', 'range(1, a)')), + ( + ('(a > 0)', 'continue', 'b += 1'), + 'range(0, a)', + ('(a > 2)', 'return a'), + ), + ('range(0, a)', '(a > 2)', ('(a > 3)', 'break')), + ('(a > 2)', 'break', 'a = 2'), + ('(a > 2)', '(a > 3)', ('(a > 4)', 'b += 1')), + ('(a > 3)', '(a > 4)', ('continue', 'max(a)')), + ('(a > 4)', 'max(a)', 'break'), + ('max(a)', 'break', 'a = 2'), + ('(a > 4)', 'continue', 'range(0, a)'), + ('(a > 3)', 'b += 1', 'range(0, a)'), + ('range(0, a)', 'return a', None), + ('break', 'a = 2', '(a > 0)'), + ('(a > 0)', 'range(1, a)', ('return b', 'a = 3')), + ('range(1, a)', 'return b', None), + ('range(1, a)', 'a = 3', None), + ), + ) + + def test_finally_straightline(self): + + def test_fn(a): + try: + a += 1 + finally: + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', 'a += 1', 'a = 2'), + ('a += 1', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_return_finally(self): + + def test_fn(a): + try: + return a + finally: + a = 1 + a = 2 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', 'return a', 'a = 1'), + ('return a', 'a = 1', None), + (None, 'a = 2', None), + ), + ) + + def test_break_finally(self): + + def test_fn(a): + while a > 0: + try: + break + finally: + a = 1 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', '(a > 0)', 'break'), + ('(a > 0)', 'break', 'a = 1'), + ('break', 'a = 1', None), + ), + ) + + def test_continue_finally(self): + + def test_fn(a): + while a > 0: + try: + continue + finally: + a = 1 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', 'continue'), + ('(a > 0)', 'continue', 'a = 1'), + ('continue', 'a = 1', '(a > 0)'), + ), + ) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..ca1441cf6f8bb034c95b37fcdd9e8158d1db2e39 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD @@ -0,0 +1,38 @@ +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "common_transformers", + srcs = [ + "anf.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/autograph/pyct", + "@gast_archive//:gast", + ], +) + +py_test( + name = "anf_test", + srcs = ["anf_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":common_transformers", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py new file mode 100644 index 0000000000000000000000000000000000000000..cc039986c219db1febfe610a5078e26eeb2d5a83 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py @@ -0,0 +1,57 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Conversion to A-normal form.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import transformer + + +class DummyGensym(object): + """A dumb gensym that suffixes a stem by sequential numbers from 1000.""" + + def __init__(self, entity_info): + del entity_info + # A proper implementation needs to account for: + # * entity_info.namespace + # * all the symbols defined in the AST + # * the symbols generated so far + self._idx = 0 + + def new_name(self, stem): + self._idx += 1 + return stem + '_' + str(1000 + self._idx) + + +class AnfTransformer(transformer.Base): + """Performs the actual conversion.""" + + # TODO(mdan): Link to a reference. + # TODO(mdan): Implement. + + def __init__(self, entity_info): + """Creates a transformer. + + Args: + entity_info: transformer.EntityInfo + """ + super(AnfTransformer, self).__init__(entity_info) + self._gensym = DummyGensym(entity_info) + + +def transform(node, entity_info): + return AnfTransformer(entity_info).visit(node) diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py new file mode 100644 index 0000000000000000000000000000000000000000..81983a5ecb7b8c6216285409f854e27b7154a08b --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for anf module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.common_transformers import anf +from tensorflow.python.platform import test + + +class AnfTransformerTest(test.TestCase): + + def _simple_source_info(self): + return transformer.EntityInfo( + source_code=None, + source_file=None, + namespace=None, + arg_values=None, + arg_types=None, + owner_type=None) + + def test_basic(self): + + def test_function(): + a = 0 + return a + + node, _ = parser.parse_entity(test_function) + node = anf.transform(node, self._simple_source_info()) + result, _ = compiler.ast_to_object(node) + + self.assertEqual(test_function(), result.test_function()) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/context.py b/tensorflow/contrib/autograph/pyct/context.py deleted file mode 100644 index b34015cfd2888f0dbeb6492b9e7335d561bf4763..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/pyct/context.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Conversion context containers.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - - -class EntityContext(object): - """Contains information about an entity, like source code. - - In general, objects of this class should be considered immutable. - - Attributes: - namer: Namer that matches the contract of all converters. - source_code: The entity's source code. - source_file: The entity's source file. - namespace: Dict[str->*], containing symbols visible to the entity - (excluding parameters). - arg_values: Dict[str->*], containing parameter values, if known. - arg_types: Dict[str->*], containing parameter types, if known. - owner_type: The surrounding class type of the function, if present. - """ - - # TODO(mdan): Remove the default and update tests. - def __init__(self, namer, source_code, source_file, namespace, arg_values, - arg_types, owner_type, recursive, type_annotation_func=None): - self.namer = namer - self.source_code = source_code - self.source_file = source_file - self.namespace = namespace - self.arg_values = {} if arg_values is None else arg_values - self.arg_types = {} if arg_types is None else arg_types - self.owner_type = owner_type - self.recursive = recursive - self.type_annotation_func = type_annotation_func diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD index 8064a967cd389e88d3febbeb21cac87b0fef9e18..bcf2dacec2062704805f1d72ec27a243159d13c1 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD +++ b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD @@ -27,6 +27,7 @@ py_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/utils", "@gast_archive//:gast", ], ) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py index fdbd349af9d3325af114a7206d89617134278f14..bc22be0a270bbc9c361aea6d6d9c255ea51796e8 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py @@ -21,9 +21,9 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.qual_names import QN from tensorflow.contrib.autograph.pyct.static_analysis import activity from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -112,18 +112,16 @@ class ActivityAnalyzerTest(test.TestCase): def _parse_and_analyze(self, test_fn): node, source = parser.parse_entity(test_fn) - ctx = context.EntityContext( - namer=None, + entity_info = transformer.EntityInfo( source_code=source, source_file=None, namespace={}, arg_values=None, arg_types=None, - owner_type=None, - recursive=True) + owner_type=None) node = qual_names.resolve(node) - node = activity.resolve(node, ctx) - return node, ctx + node = activity.resolve(node, entity_info) + return node, entity_info def test_local_markers(self): diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py index ce746feeacf373874f9852d430eb37fadaf1e89e..4acc4ed66a62b0ccd407d39b1abda00c4c88a9a1 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py @@ -276,9 +276,9 @@ class Forward(object): taken). """ - def __init__(self, label, context, transfer_fn=operator.or_): + def __init__(self, label, source_info, transfer_fn=operator.or_): self.transfer_fn = transfer_fn - self.context = context + self.source_info = source_info self.out_label = label + '_out' self.in_label = label + '_in' self.gen_label = label + '_gen' @@ -286,7 +286,7 @@ class Forward(object): # TODO(alexbw): see if we can simplify by visiting breadth-first def visit(self, node): - """Depth-first walking the CFG, applying dataflow information propagation.""" + """Depth-first walking the CFG, applying dataflow info propagation.""" # node.value is None only for the exit CfgNode. if not node.value: return @@ -399,18 +399,18 @@ class Liveness(Backward): later in the program. """ - def __init__(self, context): - super(Liveness, self).__init__('live', context) + def __init__(self, source_info): + super(Liveness, self).__init__('live', source_info) def get_gen_kill(self, node, _): # A variable's parents are live if it is live # e.g. x is live if x.y is live. This means gen needs to return # all parents of a variable (if it's an Attribute or Subscript). # This doesn't apply to kill (e.g. del x.y doesn't affect liveness of x) - gen = activity.get_read(node.value, self.context) + gen = activity.get_read(node.value, self.source_info) gen = functools.reduce(lambda left, right: left | right.support_set, gen, gen) - kill = activity.get_updated(node.value, self.context) + kill = activity.get_updated(node.value, self.source_info) return gen, kill @@ -420,11 +420,11 @@ class ReachingDefinitions(Forward): Each statement is annotated with a set of (variable, definition) pairs. """ - def __init__(self, context): - super(ReachingDefinitions, self).__init__('definitions', context) + def __init__(self, source_info): + super(ReachingDefinitions, self).__init__('definitions', source_info) def get_gen_kill(self, node, incoming): - definitions = activity.get_updated(node.value, self.context) + definitions = activity.get_updated(node.value, self.source_info) gen = frozenset((id_, node.value) for id_ in definitions) kill = frozenset(def_ for def_ in incoming if def_[0] in definitions) return gen, kill @@ -437,9 +437,10 @@ class Defined(Forward): be defined at that point. """ - def __init__(self, context): - super(Defined, self).__init__('defined', context, transfer_fn=operator.and_) + def __init__(self, source_info): + super(Defined, self).__init__( + 'defined', source_info, transfer_fn=operator.and_) def get_gen_kill(self, node, _): - gen = activity.get_updated(node.value, self.context) + gen = activity.get_updated(node.value, self.source_info) return gen, frozenset() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py index fc07fa3447b23c0595a5893329de8a2d7055ca15..428ebbedca85f9b94b4b1db0f3b36a334126196b 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py @@ -23,29 +23,26 @@ import functools import gast from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis import cfg from tensorflow.python.platform import test class CFGTest(test.TestCase): - def _parse_and_analyze(self, test_fn, namespace, arg_types=None): - arg_types = arg_types or {} + def _parse_and_analyze(self, test_fn): node, source = parser.parse_entity(test_fn) - ctx = context.EntityContext( - namer=None, + entity_info = transformer.EntityInfo( source_code=source, source_file=None, - namespace=namespace, + namespace={}, arg_values=None, - arg_types=arg_types, - owner_type=None, - recursive=True) + arg_types=None, + owner_type=None) node = qual_names.resolve(node) - return node, ctx + return node, entity_info def _check_anno_matches(self, node, anno_name, var_names): if isinstance(var_names, str): @@ -73,7 +70,7 @@ class CFGTest(test.TestCase): x = x return x - node, ctx = self._parse_and_analyze(f, {}) + node, ctx = self._parse_and_analyze(f) cfg.run_analyses(node, cfg.ReachingDefinitions(ctx)) body = node.body[0].body # Only the argument reaches the expression @@ -106,7 +103,7 @@ class CFGTest(test.TestCase): y = 2 # pylint: disable=unused-variable return x - node, ctx = self._parse_and_analyze(f, {}) + node, ctx = self._parse_and_analyze(f) cfg.run_analyses(node, cfg.Defined(ctx)) body = node.body[0].body # only x is for sure defined at the end @@ -116,7 +113,7 @@ class CFGTest(test.TestCase): self._check_anno_matches(if_body[0], 'defined_out', ('x', 'y')) def _get_live_annotated_fnbody(self, f): - node, ctx = self._parse_and_analyze(f, {}) + node, ctx = self._parse_and_analyze(f) cfg.run_analyses(node, cfg.Liveness(ctx)) body = node.body[0].body return body @@ -226,7 +223,7 @@ class CFGTest(test.TestCase): return g(x) - node, ctx = self._parse_and_analyze(f, {}) + node, ctx = self._parse_and_analyze(f) cfg.run_analyses(node, cfg.Defined(ctx)) body = node.body[0].body @@ -253,7 +250,7 @@ class CFGTest(test.TestCase): return g() # y is not defined here - node, ctx = self._parse_and_analyze(f, {}) + node, ctx = self._parse_and_analyze(f) cfg.run_analyses(node, cfg.Defined(ctx)) body = node.body[0].body self.assertEqual( @@ -282,7 +279,7 @@ class CFGTest(test.TestCase): return x, y for f in (for_orelse, while_orelse): - node, ctx = self._parse_and_analyze(f, {}) + node, ctx = self._parse_and_analyze(f) cfg.run_analyses(node, cfg.ReachingDefinitions(ctx)) body = node.body[0].body return_node = body[-1] diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py index 53ae15459097baff918432a493edd7360ebf209d..9ccb98f79adbe5410a7554548ee75ab95345962d 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py @@ -39,7 +39,7 @@ class LiveValueResolver(transformer.Base): def visit_ClassDef(self, node): self.generic_visit(node) - anno.setanno(node, 'live_val', self.context.namespace[node.name]) + anno.setanno(node, 'live_val', self.entity_info.namespace[node.name]) return node def visit_Name(self, node): @@ -55,8 +55,8 @@ class LiveValueResolver(transformer.Base): if not symbol_is_local and not symbol_is_param: if node.id in self.literals: anno.setanno(node, 'live_val', self.literals[node.id]) - elif node.id in self.context.namespace: - obj = self.context.namespace[node.id] + elif node.id in self.entity_info.namespace: + obj = self.entity_info.namespace[node.id] anno.setanno(node, 'live_val', obj) if hasattr(obj, '__name__'): anno.setanno(node, 'fqn', (obj.__name__,)) @@ -80,8 +80,8 @@ class LiveValueResolver(transformer.Base): # TODO(mdan): Use type annotations as fallback. if not symbol_is_modified: - if node.id in self.context.arg_values: - obj = self.context.arg_values[node.id] + if node.id in self.entity_info.arg_values: + obj = self.entity_info.arg_values[node.id] anno.setanno(node, 'live_val', obj) anno.setanno(node, 'fqn', (obj.__class__.__name__,)) return node diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py index 69e428bde109ed43c3cdda1a94970a832dc47852..38af79277779f77ffe31c2f6e26ae88f3e1a7ae9 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py @@ -21,9 +21,9 @@ from __future__ import print_function import six from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis import activity from tensorflow.contrib.autograph.pyct.static_analysis import live_values from tensorflow.contrib.autograph.pyct.static_analysis import type_info @@ -39,22 +39,19 @@ class LiveValuesResolverTest(test.TestCase): literals=None, arg_types=None): literals = literals or {} - arg_types = arg_types or {} node, source = parser.parse_entity(test_fn) - ctx = context.EntityContext( - namer=None, + entity_info = transformer.EntityInfo( source_code=source, source_file=None, namespace=namespace, arg_values=None, arg_types=arg_types, - owner_type=None, - recursive=True) + owner_type=None) node = qual_names.resolve(node) - node = activity.resolve(node, ctx) - node = live_values.resolve(node, ctx, literals) - node = type_info.resolve(node, ctx) - node = live_values.resolve(node, ctx, literals) + node = activity.resolve(node, entity_info) + node = live_values.resolve(node, entity_info, literals) + node = type_info.resolve(node, entity_info) + node = live_values.resolve(node, entity_info, literals) return node def test_literals(self): diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py index 7d1e65c958d7787ef5ed707d4822d14a83092975..a229c288a83e516fc02f3af8df2046c5365e569c 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py @@ -43,6 +43,7 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph import utils from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import transformer @@ -52,6 +53,7 @@ from tensorflow.python.util import tf_inspect # TODO(mdan): Remove the duplication between this and activity.py. # In particular, the symbol definitions we track here could as well be tracked # there because they follow the same rules for visibility. +# TODO(mdan): Use a CFG based Defined analysis instead. class Scope(object): """Tracks symbol value references. @@ -135,35 +137,40 @@ class TypeInfoResolver(transformer.Base): node.orelse = self._visit_block(node.orelse) return node - def _process_function_arg(self, arg_name): - str_name = str(arg_name) - type_holder = arg_name.ast() - self.scope.setval(arg_name, type_holder) - if len(self.enclosing_entities) == 1 and str_name in self.context.arg_types: + def _process_function_arg(self, arg_node): + qn = anno.getanno(arg_node, anno.Basic.QN) + arg_name = str(qn) + self.scope.setval(qn, arg_node) + if (len(self.enclosing_entities) == 1 and + arg_name in self.entity_info.arg_types): # Forge a node to hold the type information, so that method calls on # it can resolve the type. - type_string, type_obj = self.context.arg_types[str_name] - anno.setanno(type_holder, 'type', type_obj) - anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.'))) + type_string, type_obj = self.entity_info.arg_types[arg_name] + anno.setanno(arg_node, 'type', type_obj) + anno.setanno(arg_node, 'type_fqn', tuple(type_string.split('.'))) def visit_arg(self, node): - self._process_function_arg(anno.getanno(node.arg, anno.Basic.QN)) + self._process_function_arg(node.arg) return node def visit_Name(self, node): self.generic_visit(node) - qn = anno.getanno(node, anno.Basic.QN) if isinstance(node.ctx, gast.Param): - self._process_function_arg(qn) - elif isinstance(node.ctx, gast.Load) and self.scope.hasval(qn): - # E.g. if we had - # a = b - # then for future references to `a` we should have definition = `b` - definition = self.scope.getval(qn) - anno.copyanno(definition, node, 'type') - anno.copyanno(definition, node, 'type_fqn') - anno.copyanno(definition, node, 'element_type') - anno.copyanno(definition, node, 'element_shape') + self._process_function_arg(node) + elif isinstance(node.ctx, gast.Load): + qn = anno.getanno(node, anno.Basic.QN) + if self.scope.hasval(qn): + # E.g. if we had + # a = b + # then for future references to `a` we should have definition = `b` + definition = self.scope.getval(qn) + anno.copyanno(definition, node, 'type') + anno.copyanno(definition, node, 'type_fqn') + anno.setanno(node, 'definition', definition) + + # TODO(mdan): Remove this when the directives module is in. + anno.copyanno(definition, node, 'element_type') + anno.copyanno(definition, node, 'element_shape') return node def _process_variable_assignment(self, target, value): @@ -203,12 +210,12 @@ class TypeInfoResolver(transformer.Base): node.targets, node.value, self._process_variable_assignment) return node + # TODO(mdan): Remove as soon as the new directives module is ready. def visit_Call(self, node): if anno.hasanno(node.func, 'live_val'): # Symbols targeted by the "set_type" marker function are assigned the data # type that it specified. - if (anno.getanno(node.func, 'live_val') is - self.context.type_annotation_func): + if anno.getanno(node.func, 'live_val') is utils.set_element_type: if len(node.args) < 2 or len(node.args) > 3: raise ValueError('"%s" must have either two or three parameters' @@ -219,8 +226,8 @@ class TypeInfoResolver(transformer.Base): else: target_arg, type_arg, shape_arg = node.args if not anno.hasanno(target_arg, anno.Basic.QN): - raise ValueError('the first argument of "%s" must by a symbol' - % self.context.type_annotation_func) + raise ValueError('the first argument of "%s" must by a symbol' % + utils.set_element_type) # TODO(mdan): This is vulnerable to symbol renaming. element_type = type_arg element_shape = shape_arg diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py index 484562f294bb53a63feeca965b8f94c58aa2a685..32b1148ab21809514bc09a31e26f0219017bd088 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py @@ -18,11 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph import utils from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis import activity from tensorflow.contrib.autograph.pyct.static_analysis import live_values from tensorflow.contrib.autograph.pyct.static_analysis import type_info @@ -62,21 +61,18 @@ class TypeInfoResolverTest(test.TestCase): namespace, arg_types=None): node, source = parser.parse_entity(test_fn) - ctx = context.EntityContext( - namer=None, + entity_info = transformer.EntityInfo( source_code=source, source_file=None, namespace=namespace, arg_values=None, arg_types=arg_types, - owner_type=None, - recursive=True, - type_annotation_func=utils.set_element_type) + owner_type=None) node = qual_names.resolve(node) - node = activity.resolve(node, ctx) - node = live_values.resolve(node, ctx, {}) - node = type_info.resolve(node, ctx) - node = live_values.resolve(node, ctx, {}) + node = activity.resolve(node, entity_info) + node = live_values.resolve(node, entity_info, {}) + node = type_info.resolve(node, entity_info) + node = live_values.resolve(node, entity_info, {}) return node def test_constructor_detection(self): @@ -147,7 +143,7 @@ class TypeInfoResolverTest(test.TestCase): opt.minimize(0) node = self._parse_and_analyze( - test_fn, {'training': training}, + test_fn, {}, arg_types={ 'opt': (training.GradientDescentOptimizer.__name__, training.GradientDescentOptimizer) @@ -180,35 +176,6 @@ class TypeInfoResolverTest(test.TestCase): method_call = node.body[0].body[1].value.func self.assertFalse(anno.hasanno(method_call, 'live_val')) - def test_type_annotation(self): - - class Foo(object): - pass - - def test_fn(): - f = [] - f = utils.set_element_type(f, Foo, (1, 2, 3)) - return f - - node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils}) - f_def = node.body[0].body[0].value - self.assertEqual(anno.getanno(f_def, 'element_type').id, 'Foo') - f_ref = node.body[0].body[1].value - self.assertEqual(anno.getanno(f_ref, 'element_type').id, 'Foo') - - def test_type_annotation_args(self): - - class Foo(object): - pass - - def test_fn(f): - utils.set_element_type(f, Foo) - return f - - node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils}) - f_ref = node.body[0].body[1].value - self.assertEqual(anno.getanno(f_ref, 'element_type').id, 'Foo') - def test_nested_unpacking(self): class Foo(object): @@ -230,25 +197,6 @@ class TypeInfoResolverTest(test.TestCase): self.assertFalse(anno.hasanno(b, 'live_val')) self.assertFalse(anno.hasanno(c, 'live_val')) - def test_inner_scope(self): - - def test_fn(): - a = [] - utils.set_element_type(a, 1) - for _ in a: - b = [] - utils.set_element_type(b, 2) - return a, b - - node = self._parse_and_analyze(test_fn, {'utils': utils}) - a, b = node.body[0].body[2].body[2].value.elts - self.assertEquals(anno.getanno(a, 'element_type').n, 1) - self.assertEquals(anno.getanno(b, 'element_type').n, 2) - self.assertFalse(anno.hasanno(a, 'type')) - self.assertFalse(anno.hasanno(b, 'type')) - self.assertFalse(anno.hasanno(a, 'live_val')) - self.assertFalse(anno.hasanno(b, 'live_val')) - if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py index a656e99d21c6d3a1af831d3b34cf135b03c7ba29..76558118308c31a2c1a770cad814e96abd6a6063 100644 --- a/tensorflow/contrib/autograph/pyct/transformer.py +++ b/tensorflow/contrib/autograph/pyct/transformer.py @@ -32,15 +32,40 @@ class AutographParseError(SyntaxError): pass -def try_ast_to_source(node): - try: - return compiler.ast_to_source(node) - except AssertionError: - return '' +# TODO(mdan): Use namedtuple. +class EntityInfo(object): + """Contains information about a Python entity. Immutable. + + Examples of entities include functions and classes. + + Attributes: + source_code: The entity's source code. + source_file: The entity's source file. + namespace: Dict[str, ], containing symbols visible to the entity + (excluding parameters). + arg_values: dict[str->*], containing parameter values, if known. + arg_types: dict[str->*], containing parameter types, if known. + owner_type: The surrounding class type of the function, if present. + """ + + # TODO(mdan): Remove the default and update tests. + def __init__(self, source_code, source_file, namespace, arg_values, arg_types, + owner_type): + self.source_code = source_code + self.source_file = source_file + self.namespace = namespace + self.arg_values = {} if arg_values is None else arg_values + self.arg_types = {} if arg_types is None else arg_types + self.owner_type = owner_type class Base(gast.NodeTransformer): - """Base class for specialized transformers. + """Base class for general-purpose code transformers transformers. + + This is an extension of ast.NodeTransformer that provides a few additional + functions, like state tracking within the scope of arbitrary node, helpers + for processing code blocks, debugging, mapping of transformed code to + original code, and others. Scope-local state tracking: to keep state across nodes, at the level of (possibly nested) scopes, use enter/exit_local_scope and set/get_local. @@ -48,15 +73,17 @@ class Base(gast.NodeTransformer): when they are not properly paired. """ - def __init__(self, context): + # TODO(mdan): Document all extra features. + + def __init__(self, entity_info): """Initialize the transformer. Subclasses should call this. Args: - context: An EntityContext. + entity_info: An EntityInfo object. """ self._lineno = 0 self._col_offset = 0 - self.context = context + self.entity_info = entity_info self._enclosing_entities = [] # A stack that allows keeping mutable, scope-local state where scopes may be @@ -237,9 +264,15 @@ class Base(gast.NodeTransformer): # TODO(mdan): Look into allowing to rewrite the AST here. apply_fn(target, values) + def _get_source(self, node): + try: + return compiler.ast_to_source(node) + except AssertionError: + return '' + def visit(self, node): - source_code = self.context.source_code - source_file = self.context.source_file + source_code = self.entity_info.source_code + source_file = self.entity_info.source_file did_enter_function = False local_scope_size_at_entry = len(self._local_scope_state) @@ -275,7 +308,7 @@ class Base(gast.NodeTransformer): except (ValueError, AttributeError, KeyError, NotImplementedError) as e: msg = '%s: %s\nOffending source:\n%s\n\nOccurred at node:\n%s' % ( - e.__class__.__name__, str(e), try_ast_to_source(node), + e.__class__.__name__, str(e), self._get_source(node), pretty_printer.fmt(node, color=False)) if source_code: line = source_code.splitlines()[self._lineno - 1] diff --git a/tensorflow/contrib/autograph/pyct/transformer_test.py b/tensorflow/contrib/autograph/pyct/transformer_test.py index f110e79605945e908e8a49112cf758ec29fa1b11..baf04653ae862b0159fb50a1c67fa675ceb74b9a 100644 --- a/tensorflow/contrib/autograph/pyct/transformer_test.py +++ b/tensorflow/contrib/autograph/pyct/transformer_test.py @@ -21,7 +21,6 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.platform import test @@ -29,16 +28,14 @@ from tensorflow.python.platform import test class TransformerTest(test.TestCase): - def _context_for_testing(self): - return context.EntityContext( - namer=None, + def _simple_source_info(self): + return transformer.EntityInfo( source_code=None, source_file=None, namespace=None, arg_values=None, arg_types=None, - owner_type=None, - recursive=False) + owner_type=None) def test_entity_scope_tracking(self): @@ -55,7 +52,7 @@ class TransformerTest(test.TestCase): anno.setanno(node, 'enclosing_entities', self.enclosing_entities) return self.generic_visit(node) - tr = TestTransformer(self._context_for_testing()) + tr = TestTransformer(self._simple_source_info()) def test_function(): a = 0 @@ -118,7 +115,7 @@ class TransformerTest(test.TestCase): def visit_For(self, node): return self._annotate_result(node) - tr = TestTransformer(self._context_for_testing()) + tr = TestTransformer(self._simple_source_info()) def test_function(a): """Docstring.""" @@ -157,7 +154,7 @@ class TransformerTest(test.TestCase): self.exit_local_scope() return node - tr = TestTransformer(self._context_for_testing()) + tr = TestTransformer(self._simple_source_info()) def no_exit(a): if a > 0: @@ -196,7 +193,7 @@ class TransformerTest(test.TestCase): z = y return z - tr = TestTransformer(self._context_for_testing()) + tr = TestTransformer(self._simple_source_info()) node, _ = parser.parse_entity(test_function) node = tr.visit(node) diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py index 012a51f71101471850d312033c41dcbc4805d44c..47b80bdf4ad88ebce3603a14ea2aa3cbe5bd345f 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops.py @@ -119,10 +119,6 @@ def batch_function(num_batch_threads, raise ValueError("All arguments to functions decorated with " "`batch_function` are supposed to be Tensors; " "found %s" % repr(a)) - for inp in computation.captured_inputs: - print("inp: %s" % inp) - for op in inp.consumers(): - print("op: %s" % op) return gen_batch_ops.batch_function( num_batch_threads=num_batch_threads, max_batch_size=max_batch_size, diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py index d9e23646d8334014f1bef0d0744df9310b59909f..9e6a146f67796466202cc5074ddd25e4c2b083a6 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py @@ -29,7 +29,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib -from tensorflow.python.ops.distributions import gamma as gamma_lib from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.platform import test @@ -256,50 +255,6 @@ class ExpectationTest(test.TestCase): gradq_approx_kl_normal_normal_, rtol=0.01, atol=0.) - def test_docstring_example_gamma(self): - with self.test_session() as sess: - num_draws = int(1e5) - concentration_p = constant_op.constant(1.) - concentration_q = constant_op.constant(2.) - p = gamma_lib.Gamma(concentration=concentration_p, rate=1.) - q = gamma_lib.Gamma(concentration=concentration_q, rate=3.) - approx_kl_gamma_gamma = monte_carlo_lib.expectation( - f=lambda x: p.log_prob(x) - q.log_prob(x), - samples=p.sample(num_draws, seed=42), - log_prob=p.log_prob, - use_reparametrization=(p.reparameterization_type - == distribution_lib.FULLY_REPARAMETERIZED)) - exact_kl_gamma_gamma = kullback_leibler.kl_divergence(p, q) - [exact_kl_gamma_gamma_, approx_kl_gamma_gamma_] = sess.run([ - exact_kl_gamma_gamma, approx_kl_gamma_gamma]) - self.assertEqual( - False, - p.reparameterization_type == distribution_lib.FULLY_REPARAMETERIZED) - self.assertAllClose(exact_kl_gamma_gamma_, approx_kl_gamma_gamma_, - rtol=0.01, atol=0.) - - # Compare gradients. (Not present in `docstring`.) - gradp = lambda fp: gradients_impl.gradients(fp, concentration_p)[0] - gradq = lambda fq: gradients_impl.gradients(fq, concentration_q)[0] - [ - gradp_exact_kl_gamma_gamma_, - gradq_exact_kl_gamma_gamma_, - gradp_approx_kl_gamma_gamma_, - gradq_approx_kl_gamma_gamma_, - ] = sess.run([ - gradp(exact_kl_gamma_gamma), - gradq(exact_kl_gamma_gamma), - gradp(approx_kl_gamma_gamma), - gradq(approx_kl_gamma_gamma), - ]) - # Notice that variance (i.e., `rtol`) is higher when using score-trick. - self.assertAllClose(gradp_exact_kl_gamma_gamma_, - gradp_approx_kl_gamma_gamma_, - rtol=0.05, atol=0.) - self.assertAllClose(gradq_exact_kl_gamma_gamma_, - gradq_approx_kl_gamma_gamma_, - rtol=0.03, atol=0.) - if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py index 032b859d469ee5039e08e4af4c2f4ebf35c2ff19..68ead2f7609ca987180fe8973cf902f1e56b8388 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py @@ -192,7 +192,7 @@ def _logspace_mean(log_values): def expectation(f, samples, log_prob=None, use_reparametrization=True, axis=0, keep_dims=False, name=None): - """Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\). + r"""Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\). This function computes the Monte-Carlo approximation of an expectation, i.e., diff --git a/tensorflow/contrib/bigtable/BUILD b/tensorflow/contrib/bigtable/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..5c15d21e35557ba5ff25d9d943aae2809eddba4a --- /dev/null +++ b/tensorflow/contrib/bigtable/BUILD @@ -0,0 +1,196 @@ +# Cloud Bigtable client for TensorFlow + +package( + default_visibility = ["//tensorflow:internal"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load( + "//tensorflow:tensorflow.bzl", + "tf_copts", + "tf_custom_op_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", + "tf_kernel_library", + "tf_cc_test", + "tf_py_test", +) + +tf_custom_op_py_library( + name = "bigtable", + srcs = ["__init__.py"] + glob(["python/ops/*.py"]), + dso = [ + ":python/ops/_bigtable.so", + ], + kernels = [ + ":bigtable_kernels", + ":bigtable_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":bigtable_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:util", + "//tensorflow/python/data", + ], +) + +tf_custom_op_library( + name = "python/ops/_bigtable.so", + srcs = [ + "kernels/bigtable_kernels.cc", + "kernels/bigtable_lookup_dataset_op.cc", + "kernels/bigtable_prefix_key_dataset_op.cc", + "kernels/bigtable_range_key_dataset_op.cc", + "kernels/bigtable_scan_dataset_op.cc", + "ops/bigtable_ops.cc", + ], + deps = [ + ":bigtable_lib_cc", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + ], +) + +tf_gen_op_wrapper_py( + name = "bigtable_ops", + deps = [":bigtable_ops_op_lib"], +) + +tf_gen_op_libs( + op_lib_names = [ + "bigtable_ops", + "bigtable_test_ops", + ], +) + +tf_kernel_library( + name = "bigtable_kernels", + srcs = [ + "kernels/bigtable_kernels.cc", + "kernels/bigtable_lookup_dataset_op.cc", + "kernels/bigtable_prefix_key_dataset_op.cc", + "kernels/bigtable_range_key_dataset_op.cc", + "kernels/bigtable_scan_dataset_op.cc", + ], + deps = [ + ":bigtable_lib_cc", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + ], +) + +# A library for use in the bigtable kernels. +cc_library( + name = "bigtable_lib_cc", + srcs = ["kernels/bigtable_lib.cc"], + hdrs = ["kernels/bigtable_lib.h"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + ], +) + +cc_library( + name = "bigtable_test_client", + srcs = ["kernels/test_kernels/bigtable_test_client.cc"], + hdrs = ["kernels/test_kernels/bigtable_test_client.h"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "@com_github_googleapis_googleapis//:bigtable_protos", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + "@com_googlesource_code_re2//:re2", + ], +) + +tf_cc_test( + name = "bigtable_test_client_test", + srcs = ["kernels/test_kernels/bigtable_test_client_test.cc"], + tags = ["manual"], + deps = [ + ":bigtable_test_client", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + ], +) + +tf_gen_op_wrapper_py( + name = "bigtable_test_ops", + deps = [":bigtable_test_ops_op_lib"], +) + +tf_custom_op_library( + name = "python/kernel_tests/_bigtable_test.so", + srcs = [ + "kernels/test_kernels/bigtable_test_client_op.cc", + "ops/bigtable_test_ops.cc", + ], + deps = [ + ":bigtable_lib_cc", + ":bigtable_test_client", + "@com_googlesource_code_re2//:re2", + ], +) + +# Don't use tf_kernel_library because it prevents access to strings/stringprintf.h +cc_library( + name = "bigtable_test_kernels", + srcs = [ + "kernels/test_kernels/bigtable_test_client_op.cc", + ], + copts = tf_copts(), + linkstatic = 1, + deps = [ + ":bigtable_lib_cc", + ":bigtable_test_client", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@com_googlesource_code_re2//:re2", + ], + alwayslink = 1, +) + +tf_custom_op_py_library( + name = "bigtable_test_py", + dso = [ + ":python/kernel_tests/_bigtable_test.so", + ], + kernels = [ + ":bigtable_test_kernels", + ":bigtable_test_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":bigtable_test_ops", + # "//tensorflow/contrib/util:util_py", + # "//tensorflow/python:framework_for_generated_wrappers", + # "//tensorflow/python:platform", + # "//tensorflow/python:util", + # "//tensorflow/python/data", + ], +) + +tf_py_test( + name = "bigtable_ops_test", + size = "small", + srcs = ["python/kernel_tests/bigtable_ops_test.py"], + additional_deps = [ + ":bigtable", + ":bigtable_test_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:util", + ], + tags = ["manual"], +) diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ef3c60069e8a97f7a13457156d20f3f7a4f7eccb --- /dev/null +++ b/tensorflow/contrib/bigtable/README.md @@ -0,0 +1,10 @@ +# Bigtable # + +[Google Cloud Bigtable](https://cloud.google.com/bigtable/) is a high +performance storage system that can store and serve training data. This contrib +package contains an experimental integration with TensorFlow. + +> **Status: Highly experimental.** The current implementation is very much in +> flux. Please use at your own risk! :-) + + diff --git a/tensorflow/contrib/bigtable/__init__.py b/tensorflow/contrib/bigtable/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7df054637cdab32f2dd6201dd3488a90495e1cf5 --- /dev/null +++ b/tensorflow/contrib/bigtable/__init__.py @@ -0,0 +1,39 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Cloud Bigtable Client for TensorFlow. + +This contrib package allows TensorFlow to interface directly with Cloud Bigtable +for high-speed data loading. + +@@BigtableClient +@@BigTable + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigTable +from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigtableClient + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'BigTable', + 'BigtableClient', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc new file mode 100644 index 0000000000000000000000000000000000000000..0c81951d56ec491d7088dcb6f417c5351c0a2941 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc @@ -0,0 +1,313 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/threadpool.h" + +namespace tensorflow { + +namespace { + +class BigtableClientOp : public OpKernel { + public: + explicit BigtableClientOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("project_id", &project_id_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("instance_id", &instance_id_)); + OP_REQUIRES(ctx, !project_id_.empty(), + errors::InvalidArgument("project_id must be non-empty")); + OP_REQUIRES(ctx, !instance_id_.empty(), + errors::InvalidArgument("instance_id must be non-empty")); + } + + ~BigtableClientOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + + void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + if (!initialized_) { + ResourceMgr* mgr = ctx->resource_manager(); + OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); + BigtableClientResource* resource; + OP_REQUIRES_OK( + ctx, mgr->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &resource, + [this, ctx](BigtableClientResource** ret) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::shared_ptr client = + bigtable::CreateDefaultDataClient( + project_id_, instance_id_, + bigtable::ClientOptions()); + *ret = new BigtableClientResource( + project_id_, instance_id_, std::move(client)); + return Status::OK(); + })); + core::ScopedUnref resource_cleanup(resource); + initialized_ = true; + } + OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( + ctx, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex())); + } + + private: + string project_id_; + string instance_id_; + + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableClient").Device(DEVICE_CPU), + BigtableClientOp); + +class BigtableTableOp : public OpKernel { + public: + explicit BigtableTableOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("table_name", &table_)); + OP_REQUIRES(ctx, !table_.empty(), + errors::InvalidArgument("table_name must be non-empty")); + } + + ~BigtableTableOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + + void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + if (!initialized_) { + ResourceMgr* mgr = ctx->resource_manager(); + OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); + + BigtableClientResource* client_resource; + OP_REQUIRES_OK( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource)); + core::ScopedUnref unref_client(client_resource); + + BigtableTableResource* resource; + OP_REQUIRES_OK( + ctx, mgr->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &resource, + [this, client_resource](BigtableTableResource** ret) { + *ret = new BigtableTableResource(client_resource, table_); + return Status::OK(); + })); + initialized_ = true; + } + OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( + ctx, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex())); + } + + private: + string table_; // Note: this is const after construction. + + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableTable").Device(DEVICE_CPU), + BigtableTableOp); + +class ToBigtableOp : public AsyncOpKernel { + public: + explicit ToBigtableOp(OpKernelConstruction* ctx) + : AsyncOpKernel(ctx), + thread_pool_(new thread::ThreadPool( + ctx->env(), ThreadOptions(), + strings::StrCat("to_bigtable_op_", SanitizeThreadSuffix(name())), + /* num_threads = */ 1, /* low_latency_hint = */ false)) {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + // The call to `iterator->GetNext()` may block and depend on an + // inter-op thread pool thread, so we issue the call from the + // owned thread pool. + thread_pool_->Schedule([this, ctx, done]() { + const Tensor* column_families_tensor; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->input("column_families", &column_families_tensor), done); + OP_REQUIRES_ASYNC( + ctx, column_families_tensor->dims() == 1, + errors::InvalidArgument("`column_families` must be a vector."), done); + + const Tensor* columns_tensor; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input("columns", &columns_tensor), done); + OP_REQUIRES_ASYNC(ctx, columns_tensor->dims() == 1, + errors::InvalidArgument("`columns` must be a vector."), + done); + OP_REQUIRES_ASYNC( + ctx, + columns_tensor->NumElements() == + column_families_tensor->NumElements(), + errors::InvalidArgument("len(column_families) != len(columns)"), + done); + + std::vector column_families; + column_families.reserve(column_families_tensor->NumElements()); + std::vector columns; + columns.reserve(column_families_tensor->NumElements()); + for (uint64 i = 0; i < column_families_tensor->NumElements(); ++i) { + column_families.push_back(column_families_tensor->flat()(i)); + columns.push_back(columns_tensor->flat()(i)); + } + + DatasetBase* dataset; + OP_REQUIRES_OK_ASYNC( + ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset), done); + + IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx); + std::unique_ptr iterator; + OP_REQUIRES_OK_ASYNC( + ctx, + dataset->MakeIterator(&iter_ctx, "ToBigtableOpIterator", &iterator), + done); + + int64 timestamp_int; + OP_REQUIRES_OK_ASYNC( + ctx, ParseScalarArgument(ctx, "timestamp", ×tamp_int), + done); + OP_REQUIRES_ASYNC(ctx, timestamp_int >= -1, + errors::InvalidArgument("timestamp must be >= -1"), + done); + std::chrono::milliseconds timestamp(timestamp_int); + + BigtableTableResource* resource; + OP_REQUIRES_OK_ASYNC( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource), done); + core::ScopedUnref resource_cleanup(resource); + + std::vector components; + components.reserve(dataset->output_dtypes().size()); + bool end_of_sequence = false; + do { + ::bigtable::BulkMutation mutation; + // TODO(saeta): Make # of mutations configurable. + for (uint64 i = 0; i < 100 && !end_of_sequence; ++i) { + OP_REQUIRES_OK_ASYNC( + ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence), + done); + if (!end_of_sequence) { + OP_REQUIRES_OK_ASYNC( + ctx, + CreateMutation(std::move(components), column_families, columns, + timestamp, &mutation), + done); + } + components.clear(); + } + grpc::Status mutation_status; + std::vector<::bigtable::FailedMutation> failures = + resource->table().BulkApply(std::move(mutation), mutation_status); + if (!failures.empty()) { + for (const auto& failure : failures) { + LOG(ERROR) << "Failure applying mutation on row (" + << failure.original_index() + << "): " << failure.mutation().row_key() + << " - error: " << failure.status().error_message() + << " (Details: " << failure.status().error_details() + << ")."; + } + } + OP_REQUIRES_ASYNC( + ctx, failures.empty() && mutation_status.ok(), + errors::Unknown("Failure while writing to BigTable: ", + mutation_status.error_code(), " - ", + mutation_status.error_message(), " (", + mutation_status.error_details(), + "), # of mutation failures: ", failures.size(), + ". See the log for the specific error details."), + done); + } while (!end_of_sequence); + done(); + }); + } + + private: + static string SanitizeThreadSuffix(string suffix) { + string clean; + for (int i = 0; i < suffix.size(); ++i) { + const char ch = suffix[i]; + if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || + (ch >= '0' && ch <= '9') || ch == '_' || ch == '-') { + clean += ch; + } else { + clean += '_'; + } + } + return clean; + } + + Status CreateMutation(std::vector tensors, + const std::vector& column_families, + const std::vector& columns, + std::chrono::milliseconds timestamp, + ::bigtable::BulkMutation* bulk_mutation) { + if (tensors.size() != column_families.size() + 1) { + return errors::InvalidArgument( + "Iterator produced a set of Tensors shorter than expected"); + } + ::bigtable::SingleRowMutation mutation( + std::move(tensors[0].scalar()())); + for (size_t i = 1; i < tensors.size(); ++i) { + if (!TensorShapeUtils::IsScalar(tensors[i].shape())) { + return errors::Internal("Output tensor ", i, " was not a scalar"); + } + mutation.emplace_back( + ::bigtable::SetCell(column_families[i - 1], columns[i - 1], timestamp, + std::move(tensors[i].scalar()()))); + } + bulk_mutation->emplace_back(std::move(mutation)); + return Status::OK(); + } + + template + Status ParseScalarArgument(OpKernelContext* ctx, + const StringPiece& argument_name, T* output) { + const Tensor* argument_t; + TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); + if (!TensorShapeUtils::IsScalar(argument_t->shape())) { + return errors::InvalidArgument(argument_name, " must be a scalar"); + } + *output = argument_t->scalar()(); + return Status::OK(); + } + + std::unique_ptr thread_pool_; +}; + +REGISTER_KERNEL_BUILDER(Name("DatasetToBigtable").Device(DEVICE_CPU), + ToBigtableOp); + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc new file mode 100644 index 0000000000000000000000000000000000000000..2514575f30831bdcfab87eba07511fd309e8b1c2 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" + +namespace tensorflow { + +Status GrpcStatusToTfStatus(const ::grpc::Status& status) { + if (status.ok()) { + return Status::OK(); + } + auto grpc_code = status.error_code(); + if (status.error_code() == ::grpc::StatusCode::ABORTED || + status.error_code() == ::grpc::StatusCode::UNAVAILABLE || + status.error_code() == ::grpc::StatusCode::OUT_OF_RANGE) { + grpc_code = ::grpc::StatusCode::INTERNAL; + } + return Status( + static_cast<::tensorflow::error::Code>(status.error_code()), + strings::StrCat("Error reading from BigTable: ", status.error_message(), + " (Details: ", status.error_details(), ")")); +} + +string RegexFromStringSet(const std::vector& strs) { + CHECK(!strs.empty()) << "The list of strings to turn into a regex was empty."; + std::unordered_set uniq(strs.begin(), strs.end()); + if (uniq.size() == 1) { + return *uniq.begin(); + } + return str_util::Join(uniq, "|"); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h new file mode 100644 index 0000000000000000000000000000000000000000..54303cdc5ed227f8d831f4b341fb9e4e8b83bdd6 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h @@ -0,0 +1,138 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_ +#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_ + +// Note: we use bigtable/client/internal/table.h as this is the no-exception API + +#include "google/cloud/bigtable/data_client.h" +#include "google/cloud/bigtable/internal/table.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/resource_mgr.h" + +namespace tensorflow { + +Status GrpcStatusToTfStatus(const ::grpc::Status& status); + +string RegexFromStringSet(const std::vector& strs); + +class BigtableClientResource : public ResourceBase { + public: + BigtableClientResource(string project_id, string instance_id, + std::shared_ptr client) + : project_id_(std::move(project_id)), + instance_id_(std::move(instance_id)), + client_(std::move(client)) {} + + std::shared_ptr get_client() { return client_; } + + string DebugString() override { + return strings::StrCat("BigtableClientResource(project_id: ", project_id_, + ", instance_id: ", instance_id_, ")"); + } + + private: + const string project_id_; + const string instance_id_; + std::shared_ptr client_; +}; + +class BigtableTableResource : public ResourceBase { + public: + BigtableTableResource(BigtableClientResource* client, string table_name) + : client_(client), + table_name_(std::move(table_name)), + table_(client->get_client(), table_name_) { + client_->Ref(); + } + + ~BigtableTableResource() override { client_->Unref(); } + + ::bigtable::noex::Table& table() { return table_; } + + string DebugString() override { + return strings::StrCat( + "BigtableTableResource(client: ", client_->DebugString(), + ", table: ", table_name_, ")"); + } + + private: + BigtableClientResource* client_; // Ownes one ref. + const string table_name_; + ::bigtable::noex::Table table_; +}; + +// BigtableReaderDatasetIterator is an abstract class for iterators from +// datasets that are "readers" (source datasets, not transformation datasets) +// that read from Bigtable. +template +class BigtableReaderDatasetIterator : public DatasetIterator { + public: + explicit BigtableReaderDatasetIterator( + const typename DatasetIterator::Params& params) + : DatasetIterator(params), iterator_(nullptr, false) {} + + Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(EnsureIteratorInitialized()); + if (iterator_ == reader_->end()) { + grpc::Status status = reader_->Finish(); + if (status.ok()) { + *end_of_sequence = true; + return Status::OK(); + } + return GrpcStatusToTfStatus(status); + } + *end_of_sequence = false; + bigtable::Row& row = *iterator_; + Status s = ParseRow(ctx, row, out_tensors); + // Ensure we always advance. + ++iterator_; + return s; + } + + protected: + virtual ::bigtable::RowRange MakeRowRange() = 0; + virtual ::bigtable::Filter MakeFilter() = 0; + virtual Status ParseRow(IteratorContext* ctx, const ::bigtable::Row& row, + std::vector* out_tensors) = 0; + + private: + Status EnsureIteratorInitialized() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (reader_) { + return Status::OK(); + } + + auto rows = MakeRowRange(); + auto filter = MakeFilter(); + + // Note: the this in `this->dataset()` below is necessary due to namespace + // name conflicts. + reader_.reset(new ::bigtable::RowReader( + this->dataset()->table()->table().ReadRows(rows, filter))); + iterator_ = reader_->begin(); + return Status::OK(); + } + + mutex mu_; + std::unique_ptr<::bigtable::RowReader> reader_ GUARDED_BY(mu_); + ::bigtable::RowReader::iterator iterator_ GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_ diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..4b6d55a2d334b4193b4e533a7e7228acc15652d9 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc @@ -0,0 +1,220 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { + public: + using UnaryDatasetOpKernel::UnaryDatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + BigtableTableResource* table; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &table)); + + std::vector column_families; + std::vector columns; + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "column_families", + &column_families)); + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "columns", &columns)); + OP_REQUIRES( + ctx, column_families.size() == columns.size(), + errors::InvalidArgument("len(columns) != len(column_families)")); + + const uint64 num_outputs = columns.size() + 1; + std::vector output_shapes; + output_shapes.reserve(num_outputs); + DataTypeVector output_types; + output_types.reserve(num_outputs); + for (uint64 i = 0; i < num_outputs; ++i) { + output_shapes.push_back({}); + output_types.push_back(DT_STRING); + } + + *output = + new Dataset(ctx, input, table, std::move(column_families), + std::move(columns), output_types, std::move(output_shapes)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, + BigtableTableResource* table, + std::vector column_families, + std::vector columns, + const DataTypeVector& output_types, + std::vector output_shapes) + : GraphDatasetBase(ctx), + input_(input), + table_(table), + column_families_(std::move(column_families)), + columns_(std::move(columns)), + output_types_(output_types), + output_shapes_(std::move(output_shapes)), + filter_(MakeFilter(column_families_, columns_)) { + table_->Ref(); + input_->Ref(); + } + + ~Dataset() override { + table_->Unref(); + input_->Unref(); + } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr(new Iterator( + {this, strings::StrCat(prefix, "::BigtableLookupDataset")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "BigtableLookupDatasetOp::Dataset"; + } + + private: + static ::bigtable::Filter MakeFilter( + const std::vector& column_families, + const std::vector& columns) { + string column_family_regex = RegexFromStringSet(column_families); + string column_regex = RegexFromStringSet(columns); + + return ::bigtable::Filter::Chain( + ::bigtable::Filter::Latest(1), + ::bigtable::Filter::FamilyRegex(column_family_regex), + ::bigtable::Filter::ColumnRegex(column_regex)); + } + + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); // Sequence requests. + std::vector input_tensors; + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, &input_tensors, end_of_sequence)); + if (*end_of_sequence) { + return Status::OK(); + } + if (input_tensors.size() != 1) { + return errors::InvalidArgument( + "Upstream iterator (", dataset()->input_->DebugString(), + ") did not produce a single `tf.string` `tf.Tensor`. It " + "produced ", + input_tensors.size(), " tensors."); + } + if (input_tensors[0].NumElements() == 0) { + return errors::InvalidArgument("Upstream iterator (", + dataset()->input_->DebugString(), + ") return an empty set of keys."); + } + if (input_tensors[0].NumElements() == 1) { + // Single key lookup. + ::grpc::Status status; + auto pair = dataset()->table_->table().ReadRow( + input_tensors[0].scalar()(), dataset()->filter_, status); + if (!status.ok()) { + return GrpcStatusToTfStatus(status); + } + if (!pair.first) { + return errors::DataLoss("Row key '", + input_tensors[0].scalar()(), + "' not found."); + } + TF_RETURN_IF_ERROR(ParseRow(ctx, pair.second, out_tensors)); + } else { + // Batched get. + return errors::Unimplemented( + "BigtableLookupDataset doesn't yet support batched retrieval."); + } + return Status::OK(); + } + + private: + Status ParseRow(IteratorContext* ctx, const ::bigtable::Row& row, + std::vector* out_tensors) { + out_tensors->reserve(dataset()->columns_.size() + 1); + Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {}); + row_key_tensor.scalar()() = string(row.row_key()); + out_tensors->emplace_back(std::move(row_key_tensor)); + + if (row.cells().size() > 2 * dataset()->columns_.size()) { + LOG(WARNING) << "An excessive number of columns (" + << row.cells().size() + << ") were retrieved when reading row: " + << row.row_key(); + } + + for (uint64 i = 0; i < dataset()->columns_.size(); ++i) { + Tensor col_tensor(ctx->allocator({}), DT_STRING, {}); + bool found_column = false; + for (auto cell_itr = row.cells().begin(); + !found_column && cell_itr != row.cells().end(); ++cell_itr) { + if (cell_itr->family_name() == dataset()->column_families_[i] && + string(cell_itr->column_qualifier()) == + dataset()->columns_[i]) { + col_tensor.scalar()() = string(cell_itr->value()); + found_column = true; + } + } + if (!found_column) { + return errors::DataLoss("Column ", dataset()->column_families_[i], + ":", dataset()->columns_[i], + " not found in row: ", row.row_key()); + } + out_tensors->emplace_back(std::move(col_tensor)); + } + return Status::OK(); + } + + mutex mu_; + std::unique_ptr input_impl_ GUARDED_BY(mu_); + }; + + const DatasetBase* const input_; + BigtableTableResource* table_; + const std::vector column_families_; + const std::vector columns_; + const DataTypeVector output_types_; + const std::vector output_shapes_; + const ::bigtable::Filter filter_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableLookupDataset").Device(DEVICE_CPU), + BigtableLookupDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..3d5c3cfdaa3d78ea1d4ecc89fcbd5dfc4a46d670 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc @@ -0,0 +1,103 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtablePrefixKeyDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + string prefix; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefix", &prefix)); + + BigtableTableResource* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + + *output = new Dataset(ctx, resource, std::move(prefix)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, + string prefix) + : GraphDatasetBase(ctx), table_(table), prefix_(std::move(prefix)) { + table_->Ref(); + } + + ~Dataset() override { table_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr(new Iterator( + {this, strings::StrCat(prefix, "::BigtablePrefixKeyDataset")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); + return *dtypes; + } + + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}}); + return *shapes; + } + + string DebugString() const override { + return "BigtablePrefixKeyDatasetOp::Dataset"; + } + + BigtableTableResource* table() const { return table_; } + + private: + class Iterator : public BigtableReaderDatasetIterator { + public: + explicit Iterator(const Params& params) + : BigtableReaderDatasetIterator(params) {} + + ::bigtable::RowRange MakeRowRange() override { + return ::bigtable::RowRange::Prefix(dataset()->prefix_); + } + ::bigtable::Filter MakeFilter() override { + return ::bigtable::Filter::Chain( + ::bigtable::Filter::CellsRowLimit(1), + ::bigtable::Filter::StripValueTransformer()); + } + Status ParseRow(IteratorContext* ctx, const ::bigtable::Row& row, + std::vector* out_tensors) override { + Tensor output_tensor(ctx->allocator({}), DT_STRING, {}); + output_tensor.scalar()() = string(row.row_key()); + out_tensors->emplace_back(std::move(output_tensor)); + return Status::OK(); + } + }; + + BigtableTableResource* const table_; + const string prefix_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtablePrefixKeyDataset").Device(DEVICE_CPU), + BigtablePrefixKeyDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..7fa06052c5d2e057740ba75c984b196d8eb56cbe --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc @@ -0,0 +1,111 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtableRangeKeyDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + string start_key; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "start_key", &start_key)); + string end_key; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "end_key", &end_key)); + + BigtableTableResource* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + + *output = + new Dataset(ctx, resource, std::move(start_key), std::move(end_key)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, + string start_key, string end_key) + : GraphDatasetBase(ctx), + table_(table), + start_key_(std::move(start_key)), + end_key_(std::move(end_key)) { + table_->Ref(); + } + + ~Dataset() override { table_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr(new Iterator( + {this, strings::StrCat(prefix, "::BigtableRangeKeyDataset")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); + return *dtypes; + } + + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}}); + return *shapes; + } + + string DebugString() const override { + return "BigtableRangeKeyDatasetOp::Dataset"; + } + + BigtableTableResource* table() const { return table_; } + + private: + class Iterator : public BigtableReaderDatasetIterator { + public: + explicit Iterator(const Params& params) + : BigtableReaderDatasetIterator(params) {} + + ::bigtable::RowRange MakeRowRange() override { + return ::bigtable::RowRange::Range(dataset()->start_key_, + dataset()->end_key_); + } + ::bigtable::Filter MakeFilter() override { + return ::bigtable::Filter::Chain( + ::bigtable::Filter::CellsRowLimit(1), + ::bigtable::Filter::StripValueTransformer()); + } + Status ParseRow(IteratorContext* ctx, const ::bigtable::Row& row, + std::vector* out_tensors) override { + Tensor output_tensor(ctx->allocator({}), DT_STRING, {}); + output_tensor.scalar()() = string(row.row_key()); + out_tensors->emplace_back(std::move(output_tensor)); + return Status::OK(); + } + }; + + BigtableTableResource* const table_; + const string start_key_; + const string end_key_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableRangeKeyDataset").Device(DEVICE_CPU), + BigtableRangeKeyDatasetOp); +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..11b9bd2bdc3be6b84c8206da7becb3e708a31844 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc @@ -0,0 +1,214 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtableScanDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + string prefix; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefix", &prefix)); + string start_key; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "start_key", &start_key)); + string end_key; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "end_key", &end_key)); + + OP_REQUIRES(ctx, !(prefix.empty() && start_key.empty()), + errors::InvalidArgument( + "Either prefix or start_key must be specified")); + OP_REQUIRES(ctx, prefix.empty() || start_key.empty(), + errors::InvalidArgument( + "Only one of prefix and start_key can be provided")); + if (!prefix.empty()) { + OP_REQUIRES(ctx, end_key.empty(), + errors::InvalidArgument( + "If prefix is specified, end_key must be empty.")); + } + + std::vector column_families; + std::vector columns; + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "column_families", + &column_families)); + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "columns", &columns)); + OP_REQUIRES( + ctx, column_families.size() == columns.size(), + errors::InvalidArgument("len(columns) != len(column_families)")); + OP_REQUIRES(ctx, !column_families.empty(), + errors::InvalidArgument("`column_families` is empty")); + + float probability = 0; + OP_REQUIRES_OK( + ctx, ParseScalarArgument(ctx, "probability", &probability)); + OP_REQUIRES( + ctx, probability > 0 && probability <= 1, + errors::InvalidArgument( + "Probability outside the range of (0, 1]. Got: ", probability)); + + BigtableTableResource* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + + const uint64 num_outputs = columns.size() + 1; + std::vector output_shapes; + output_shapes.reserve(num_outputs); + DataTypeVector output_types; + output_types.reserve(num_outputs); + for (uint64 i = 0; i < num_outputs; ++i) { + output_shapes.push_back({}); + output_types.push_back(DT_STRING); + } + + *output = new Dataset(ctx, resource, std::move(prefix), + std::move(start_key), std::move(end_key), + std::move(column_families), std::move(columns), + probability, output_types, std::move(output_shapes)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, + string prefix, string start_key, string end_key, + std::vector column_families, + std::vector columns, float probability, + const DataTypeVector& output_types, + std::vector output_shapes) + : GraphDatasetBase(ctx), + table_(table), + prefix_(std::move(prefix)), + start_key_(std::move(start_key)), + end_key_(std::move(end_key)), + column_families_(std::move(column_families)), + columns_(std::move(columns)), + column_family_regex_(RegexFromStringSet(column_families_)), + column_regex_(RegexFromStringSet(columns_)), + probability_(probability), + output_types_(output_types), + output_shapes_(std::move(output_shapes)) { + table_->Ref(); + } + + ~Dataset() override { table_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr(new Iterator( + {this, strings::StrCat(prefix, "::BigtableScanDataset")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "BigtableScanDatasetOp::Dataset"; + } + + BigtableTableResource* table() const { return table_; } + + private: + class Iterator : public BigtableReaderDatasetIterator { + public: + explicit Iterator(const Params& params) + : BigtableReaderDatasetIterator(params) {} + + ::bigtable::RowRange MakeRowRange() override { + if (!dataset()->prefix_.empty()) { + DCHECK(dataset()->start_key_.empty()); + return ::bigtable::RowRange::Prefix(dataset()->prefix_); + } else { + DCHECK(!dataset()->start_key_.empty()) + << "Both prefix and start_key were empty!"; + return ::bigtable::RowRange::Range(dataset()->start_key_, + dataset()->end_key_); + } + } + ::bigtable::Filter MakeFilter() override { + // TODO(saeta): Investigate optimal ordering here. + return ::bigtable::Filter::Chain( + ::bigtable::Filter::Latest(1), + ::bigtable::Filter::FamilyRegex(dataset()->column_family_regex_), + ::bigtable::Filter::ColumnRegex(dataset()->column_regex_), + dataset()->probability_ != 1.0 + ? ::bigtable::Filter::RowSample(dataset()->probability_) + : ::bigtable::Filter::PassAllFilter()); + } + Status ParseRow(IteratorContext* ctx, const ::bigtable::Row& row, + std::vector* out_tensors) override { + out_tensors->reserve(dataset()->columns_.size() + 1); + Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {}); + row_key_tensor.scalar()() = string(row.row_key()); + out_tensors->emplace_back(std::move(row_key_tensor)); + + if (row.cells().size() > 2 * dataset()->columns_.size()) { + LOG(WARNING) << "An excessive number of columns (" + << row.cells().size() + << ") were retrieved when reading row: " + << row.row_key(); + } + + for (uint64 i = 0; i < dataset()->columns_.size(); ++i) { + Tensor col_tensor(ctx->allocator({}), DT_STRING, {}); + bool found_column = false; + for (auto cell_itr = row.cells().begin(); + !found_column && cell_itr != row.cells().end(); ++cell_itr) { + if (cell_itr->family_name() == dataset()->column_families_[i] && + string(cell_itr->column_qualifier()) == + dataset()->columns_[i]) { + col_tensor.scalar()() = string(cell_itr->value()); + found_column = true; + } + } + if (!found_column) { + return errors::InvalidArgument( + "Column ", dataset()->column_families_[i], ":", + dataset()->columns_[i], " not found in row: ", row.row_key()); + } + out_tensors->emplace_back(std::move(col_tensor)); + } + return Status::OK(); + } + }; + + BigtableTableResource* table_; + const string prefix_; + const string start_key_; + const string end_key_; + const std::vector column_families_; + const std::vector columns_; + const string column_family_regex_; + const string column_regex_; + const float probability_; + const DataTypeVector output_types_; + const std::vector output_shapes_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableScanDataset").Device(DEVICE_CPU), + BigtableScanDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc new file mode 100644 index 0000000000000000000000000000000000000000..0f107f169cfa1e9c9158be270323e09250388724 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc @@ -0,0 +1,367 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h" + +#include "google/bigtable/v2/data.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "re2/re2.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/util/ptr_util.h" +// #include "util/task/codes.pb.h" + +namespace tensorflow { +namespace { + +void UpdateRow(const ::google::bigtable::v2::Mutation& mut, + std::map* row) { + if (mut.has_set_cell()) { + auto col = + strings::Printf("%s:%s", mut.set_cell().family_name().c_str(), + string(mut.set_cell().column_qualifier()).c_str()); + (*row)[col] = string(mut.set_cell().value()); + } else if (mut.has_delete_from_column()) { + auto col = strings::Printf( + "%s:%s", mut.delete_from_column().family_name().c_str(), + string(mut.delete_from_column().column_qualifier()).c_str()); + row->erase(col); + } else if (mut.has_delete_from_family()) { + auto itr = row->lower_bound(mut.delete_from_family().family_name()); + auto prefix = + strings::Printf("%s:", mut.delete_from_family().family_name().c_str()); + while (itr != row->end() && itr->first.substr(0, prefix.size()) == prefix) { + row->erase(itr); + } + } else if (mut.has_delete_from_row()) { + row->clear(); + } else { + LOG(ERROR) << "Unknown mutation: " << mut.ShortDebugString(); + } +} + +} // namespace + +class SampleRowKeysResponse : public grpc::ClientReaderInterface< + google::bigtable::v2::SampleRowKeysResponse> { + public: + explicit SampleRowKeysResponse(BigtableTestClient* client) + : client_(client) {} + + bool NextMessageSize(uint32_t* sz) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + *sz = 10000; // A sufficiently high enough value to not worry about. + return true; + } + + bool Read(google::bigtable::v2::SampleRowKeysResponse* resp) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + sent_first_message_ = true; + + mutex_lock l2(client_->mu_); + *resp = google::bigtable::v2::SampleRowKeysResponse(); + resp->set_row_key(client_->table_.rows.begin()->first); + resp->set_offset_bytes(0); + return true; + } + + grpc::Status Finish() override { return grpc::Status::OK; } + + void WaitForInitialMetadata() override {} // Do nothing. + + private: + mutex mu_; + bool sent_first_message_ GUARDED_BY(mu_) = false; + BigtableTestClient* client_; // Not owned. +}; + +class ReadRowsResponse : public grpc::ClientReaderInterface< + google::bigtable::v2::ReadRowsResponse> { + public: + ReadRowsResponse(BigtableTestClient* client, + google::bigtable::v2::ReadRowsRequest const& request) + : client_(client), request_(request) {} + + bool NextMessageSize(uint32_t* sz) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + *sz = 10000000; // A sufficiently high enough value to not worry about. + return true; + } + + bool Read(google::bigtable::v2::ReadRowsResponse* resp) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + sent_first_message_ = true; + RowFilter filter = MakeRowFilter(); + + mutex_lock l2(client_->mu_); + *resp = google::bigtable::v2::ReadRowsResponse(); + // Send all contents in first response. + for (auto itr = client_->table_.rows.begin(); + itr != client_->table_.rows.end(); ++itr) { + if (filter.AllowRow(itr->first)) { + ::google::bigtable::v2::ReadRowsResponse_CellChunk* chunk = nullptr; + bool sent_first = false; + for (auto col_itr = itr->second.columns.begin(); + col_itr != itr->second.columns.end(); ++col_itr) { + if (filter.AllowColumn(col_itr->first)) { + chunk = resp->add_chunks(); + if (!sent_first) { + sent_first = true; + chunk->set_row_key(itr->first); + } + auto colon_idx = col_itr->first.find(":"); + CHECK(colon_idx != string::npos) + << "No ':' found in: " << col_itr->first; + chunk->mutable_family_name()->set_value( + string(col_itr->first, 0, colon_idx)); + chunk->mutable_qualifier()->set_value( + string(col_itr->first, ++colon_idx)); + if (!filter.strip_values) { + chunk->set_value(col_itr->second); + } + if (filter.only_one_column) { + break; + } + } + } + if (sent_first) { + // We are sending this row, so set the commit flag on the last chunk. + chunk->set_commit_row(true); + } + } + } + return true; + } + + grpc::Status Finish() override { return grpc::Status::OK; } + + void WaitForInitialMetadata() override {} // Do nothing. + + private: + struct RowFilter { + std::set row_set; + std::vector> row_ranges; + double row_sample = 0.0; // Note: currently ignored. + std::unique_ptr col_filter; + bool strip_values = false; + bool only_one_column = false; + + bool AllowRow(const string& row) { + if (row_set.find(row) != row_set.end()) { + return true; + } + for (const auto& range : row_ranges) { + if (range.first <= row && range.second > row) { + return true; + } + } + return false; + } + + bool AllowColumn(const string& col) { + if (col_filter) { + return RE2::FullMatch(col, *col_filter); + } else { + return true; + } + } + }; + + RowFilter MakeRowFilter() { + RowFilter filter; + for (auto i = request_.rows().row_keys().begin(); + i != request_.rows().row_keys().end(); ++i) { + filter.row_set.insert(string(*i)); + } + for (auto i = request_.rows().row_ranges().begin(); + i != request_.rows().row_ranges().end(); ++i) { + if (i->start_key_case() != + google::bigtable::v2::RowRange::kStartKeyClosed || + i->end_key_case() != google::bigtable::v2::RowRange::kEndKeyOpen) { + LOG(WARNING) << "Skipping row range that cannot be processed: " + << i->ShortDebugString(); + continue; + } + filter.row_ranges.emplace_back(std::make_pair( + string(i->start_key_closed()), string(i->end_key_open()))); + } + if (request_.filter().has_chain()) { + string family_filter; + string qualifier_filter; + for (auto i = request_.filter().chain().filters().begin(); + i != request_.filter().chain().filters().end(); ++i) { + switch (i->filter_case()) { + case google::bigtable::v2::RowFilter::kFamilyNameRegexFilter: + family_filter = i->family_name_regex_filter(); + break; + case google::bigtable::v2::RowFilter::kColumnQualifierRegexFilter: + qualifier_filter = i->column_qualifier_regex_filter(); + break; + case google::bigtable::v2::RowFilter::kCellsPerColumnLimitFilter: + if (i->cells_per_column_limit_filter() != 1) { + LOG(ERROR) << "Unexpected cells_per_column_limit_filter: " + << i->cells_per_column_limit_filter(); + } + break; + case google::bigtable::v2::RowFilter::kStripValueTransformer: + filter.strip_values = i->strip_value_transformer(); + break; + case google::bigtable::v2::RowFilter::kRowSampleFilter: + LOG(INFO) << "Ignoring row sample directive."; + break; + case google::bigtable::v2::RowFilter::kPassAllFilter: + break; + case google::bigtable::v2::RowFilter::kCellsPerRowLimitFilter: + filter.only_one_column = true; + break; + default: + LOG(WARNING) << "Ignoring unknown filter type: " + << i->ShortDebugString(); + } + } + if (family_filter.empty() || qualifier_filter.empty()) { + LOG(WARNING) << "Missing regex!"; + } else { + string regex = strings::Printf("%s:%s", family_filter.c_str(), + qualifier_filter.c_str()); + filter.col_filter.reset(new RE2(regex)); + } + } else { + LOG(WARNING) << "Read request did not have a filter chain specified: " + << request_.filter().DebugString(); + } + return filter; + } + + mutex mu_; + bool sent_first_message_ GUARDED_BY(mu_) = false; + BigtableTestClient* client_; // Not owned. + const google::bigtable::v2::ReadRowsRequest request_; +}; + +class MutateRowsResponse : public grpc::ClientReaderInterface< + google::bigtable::v2::MutateRowsResponse> { + public: + explicit MutateRowsResponse(size_t num_successes) + : num_successes_(num_successes) {} + + bool NextMessageSize(uint32_t* sz) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + *sz = 10000000; // A sufficiently high enough value to not worry about. + return true; + } + + bool Read(google::bigtable::v2::MutateRowsResponse* resp) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + sent_first_message_ = true; + *resp = google::bigtable::v2::MutateRowsResponse(); + for (size_t i = 0; i < num_successes_; ++i) { + auto entry = resp->add_entries(); + entry->set_index(i); + } + return true; + } + + grpc::Status Finish() override { return grpc::Status::OK; } + + void WaitForInitialMetadata() override {} // Do nothing. + + private: + const size_t num_successes_; + + mutex mu_; + bool sent_first_message_ = false; +}; + +grpc::Status BigtableTestClient::MutateRow( + grpc::ClientContext* context, + google::bigtable::v2::MutateRowRequest const& request, + google::bigtable::v2::MutateRowResponse* response) { + mutex_lock l(mu_); + auto* row = &table_.rows[string(request.row_key())]; + for (int i = 0; i < request.mutations_size(); ++i) { + UpdateRow(request.mutations(i), &row->columns); + } + *response = google::bigtable::v2::MutateRowResponse(); + return grpc::Status::OK; +} +grpc::Status BigtableTestClient::CheckAndMutateRow( + grpc::ClientContext* context, + google::bigtable::v2::CheckAndMutateRowRequest const& request, + google::bigtable::v2::CheckAndMutateRowResponse* response) { + return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, + "CheckAndMutateRow not implemented."); +} +grpc::Status BigtableTestClient::ReadModifyWriteRow( + grpc::ClientContext* context, + google::bigtable::v2::ReadModifyWriteRowRequest const& request, + google::bigtable::v2::ReadModifyWriteRowResponse* response) { + return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, + "ReadModifyWriteRow not implemented."); +} +std::unique_ptr< + grpc::ClientReaderInterface> +BigtableTestClient::ReadRows( + grpc::ClientContext* context, + google::bigtable::v2::ReadRowsRequest const& request) { + return MakeUnique(this, request); +} + +std::unique_ptr< + grpc::ClientReaderInterface> +BigtableTestClient::SampleRowKeys( + grpc::ClientContext* context, + google::bigtable::v2::SampleRowKeysRequest const& request) { + return MakeUnique(this); +} +std::unique_ptr< + grpc::ClientReaderInterface> +BigtableTestClient::MutateRows( + grpc::ClientContext* context, + google::bigtable::v2::MutateRowsRequest const& request) { + mutex_lock l(mu_); + for (auto i = request.entries().begin(); i != request.entries().end(); ++i) { + auto* row = &table_.rows[string(i->row_key())]; + for (auto mut = i->mutations().begin(); mut != i->mutations().end(); + ++mut) { + UpdateRow(*mut, &row->columns); + } + } + return MakeUnique(request.entries_size()); +} + +std::shared_ptr BigtableTestClient::Channel() { + LOG(WARNING) << "Call to InMemoryDataClient::Channel(); this will likely " + "cause a crash!"; + return nullptr; +} +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h new file mode 100644 index 0000000000000000000000000000000000000000..dcce6a33a7ce133939fa12db210028771825c290 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h @@ -0,0 +1,87 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_ +#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_ + +#include "google/cloud/bigtable/data_client.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +class BigtableTestClient : public ::bigtable::DataClient { + public: + std::string const& project_id() const override { return project_id_; } + std::string const& instance_id() const override { return instance_id_; } + void reset() override { + mutex_lock l(mu_); + table_ = Table(); + } + + grpc::Status MutateRow( + grpc::ClientContext* context, + google::bigtable::v2::MutateRowRequest const& request, + google::bigtable::v2::MutateRowResponse* response) override; + + grpc::Status CheckAndMutateRow( + grpc::ClientContext* context, + google::bigtable::v2::CheckAndMutateRowRequest const& request, + google::bigtable::v2::CheckAndMutateRowResponse* response) override; + + grpc::Status ReadModifyWriteRow( + grpc::ClientContext* context, + google::bigtable::v2::ReadModifyWriteRowRequest const& request, + google::bigtable::v2::ReadModifyWriteRowResponse* response) override; + + std::unique_ptr< + grpc::ClientReaderInterface> + ReadRows(grpc::ClientContext* context, + google::bigtable::v2::ReadRowsRequest const& request) override; + std::unique_ptr< + grpc::ClientReaderInterface> + SampleRowKeys( + grpc::ClientContext* context, + google::bigtable::v2::SampleRowKeysRequest const& request) override; + + std::unique_ptr< + grpc::ClientReaderInterface> + MutateRows(grpc::ClientContext* context, + google::bigtable::v2::MutateRowsRequest const& request) override; + + std::shared_ptr Channel() override; + + private: + friend class SampleRowKeysResponse; + friend class ReadRowsResponse; + friend class MutateRowsResponse; + + struct Row { + string row_key; + std::map columns; + }; + struct Table { + std::map rows; + }; + + mutex mu_; + const std::string project_id_ = "testproject"; + const std::string instance_id_ = "testinstance"; + Table table_ GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_ diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..f9be9ec6e231efb445360c90790179f9f706d352 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc @@ -0,0 +1,77 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace tensorflow { + +namespace { + +class BigtableTestClientOp : public OpKernel { + public: + explicit BigtableTestClientOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + ~BigtableTestClientOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + if (!initialized_) { + ResourceMgr* mgr = ctx->resource_manager(); + OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); + BigtableClientResource* resource; + OP_REQUIRES_OK(ctx, + mgr->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &resource, + [this, ctx](BigtableClientResource** ret) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::shared_ptr client( + new BigtableTestClient()); + // Note: must make explicit copies to sequence + // them before the move of client. + string project_id = client->project_id(); + string instance_id = client->instance_id(); + *ret = new BigtableClientResource( + std::move(project_id), + std::move(instance_id), std::move(client)); + return Status::OK(); + })); + initialized_ = true; + } + OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( + ctx, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex())); + } + + private: + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableTestClient").Device(DEVICE_CPU), + BigtableTestClientOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bd362f7de558cf7269526341f05014296b34f6b0 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc @@ -0,0 +1,279 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h" +#include "google/cloud/bigtable/internal/table.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +void WriteCell(const string& row, const string& family, const string& column, + const string& value, ::bigtable::noex::Table* table) { + ::bigtable::SingleRowMutation mut(row); + mut.emplace_back(::bigtable::SetCell(family, column, value)); + table->Apply(std::move(mut)); +} + +TEST(BigtableTestClientTest, EmptyRowRead) { + std::shared_ptr<::bigtable::DataClient> client_ptr = + std::make_shared(); + ::bigtable::noex::Table table(client_ptr, "test_table"); + + ::bigtable::RowSet rowset; + rowset.Append("r1"); + auto filter = ::bigtable::Filter::Chain(::bigtable::Filter::Latest(1)); + auto rows = table.ReadRows(std::move(rowset), filter); + EXPECT_EQ(rows.begin(), rows.end()) << "Some rows were returned in response!"; + EXPECT_TRUE(rows.Finish().ok()) << "Error reading rows."; +} + +TEST(BigtableTestClientTest, SingleRowWriteAndRead) { + std::shared_ptr<::bigtable::DataClient> client_ptr = + std::make_shared(); + ::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + + ::bigtable::RowSet rowset("r1"); + auto filter = ::bigtable::Filter::Chain(::bigtable::Filter::Latest(1)); + auto rows = table.ReadRows(std::move(rowset), filter); + auto itr = rows.begin(); + EXPECT_NE(itr, rows.end()) << "No rows were returned in response!"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + EXPECT_EQ(itr, rows.end()); + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, MultiRowWriteAndSingleRowRead) { + std::shared_ptr<::bigtable::DataClient> client_ptr = + std::make_shared(); + ::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + ::bigtable::RowSet rowset("r1"); + auto filter = ::bigtable::Filter::Chain(::bigtable::Filter::Latest(1)); + auto rows = table.ReadRows(std::move(rowset), filter); + auto itr = rows.begin(); + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, MultiRowWriteAndRead) { + std::shared_ptr<::bigtable::DataClient> client_ptr = + std::make_shared(); + ::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + ::bigtable::RowSet rowset("r1", "r2", "r3"); + auto filter = ::bigtable::Filter::Chain(::bigtable::Filter::Latest(1)); + auto rows = table.ReadRows(std::move(rowset), filter); + auto itr = rows.begin(); + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r2"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v2"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r3"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v3"); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, MultiRowWriteAndPrefixRead) { + std::shared_ptr<::bigtable::DataClient> client_ptr = + std::make_shared(); + ::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + auto filter = ::bigtable::Filter::Chain(::bigtable::Filter::Latest(1)); + auto rows = table.ReadRows(::bigtable::RowRange::Prefix("r"), filter); + auto itr = rows.begin(); + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r2"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v2"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r3"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v3"); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, ColumnFiltering) { + std::shared_ptr<::bigtable::DataClient> client_ptr = + std::make_shared(); + ::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + // Extra cells + WriteCell("r1", "f2", "c1", "v1", &table); + WriteCell("r2", "f2", "c1", "v2", &table); + WriteCell("r3", "f1", "c2", "v3", &table); + + auto filter = ::bigtable::Filter::Chain( + ::bigtable::Filter::Latest(1), ::bigtable::Filter::FamilyRegex("f1"), + ::bigtable::Filter::ColumnRegex("c1")); + auto rows = table.ReadRows(::bigtable::RowRange::Prefix("r"), filter); + auto itr = rows.begin(); + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r2"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v2"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r3"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v3"); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, RowKeys) { + std::shared_ptr<::bigtable::DataClient> client_ptr = + std::make_shared(); + ::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + // Extra cells + WriteCell("r1", "f2", "c1", "v1", &table); + WriteCell("r2", "f2", "c1", "v2", &table); + WriteCell("r3", "f1", "c2", "v3", &table); + + auto filter = ::bigtable::Filter::Chain( + ::bigtable::Filter::Latest(1), ::bigtable::Filter::CellsRowLimit(1), + ::bigtable::Filter::StripValueTransformer()); + auto rows = table.ReadRows(::bigtable::RowRange::Prefix("r"), filter); + auto itr = rows.begin(); + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), ""); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r2"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), ""); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r3"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), ""); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..17ecc3dcb24f35a80cbc904ea11df3eff3fce6b9 --- /dev/null +++ b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +// TODO(saeta): Add support for setting ClientOptions values. +REGISTER_OP("BigtableClient") + .Attr("project_id: string") + .Attr("instance_id: string") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Output("client: resource") + .SetShapeFn(shape_inference::ScalarShape); + +// TODO(saeta): Add support for Application Profiles. +// See https://cloud.google.com/bigtable/docs/app-profiles for more info. +REGISTER_OP("BigtableTable") + .Input("client: resource") + .Attr("table_name: string") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Output("table: resource") + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("DatasetToBigtable") + .Input("table: resource") + .Input("input_dataset: variant") + .Input("column_families: string") + .Input("columns: string") + .Input("timestamp: int64") + .SetShapeFn(shape_inference::NoOutputs); + +REGISTER_OP("BigtableLookupDataset") + .Input("keys_dataset: variant") + .Input("table: resource") + .Input("column_families: string") + .Input("columns: string") + .Output("handle: variant") + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("BigtablePrefixKeyDataset") + .Input("table: resource") + .Input("prefix: string") + .Output("handle: variant") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("BigtableRangeKeyDataset") + .Input("table: resource") + .Input("start_key: string") + .Input("end_key: string") + .Output("handle: variant") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape); + +// TODO(saeta): Support continuing despite bad data (e.g. empty string, or +// skip incomplete row.) +REGISTER_OP("BigtableScanDataset") + .Input("table: resource") + .Input("prefix: string") + .Input("start_key: string") + .Input("end_key: string") + .Input("column_families: string") + .Input("columns: string") + .Input("probability: float") + .Output("handle: variant") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape); + +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/ops/bigtable_test_ops.cc b/tensorflow/contrib/bigtable/ops/bigtable_test_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..f7d02458f63d547000f00b184b3d5e3c5007fb72 --- /dev/null +++ b/tensorflow/contrib/bigtable/ops/bigtable_test_ops.cc @@ -0,0 +1,27 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("BigtableTestClient") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Output("client: resource") + .SetShapeFn(shape_inference::ScalarShape); + +} // namespace tensorflow diff --git a/tensorflow/contrib/control_flow/__init__.py b/tensorflow/contrib/bigtable/python/kernel_tests/__init__.py similarity index 73% rename from tensorflow/contrib/control_flow/__init__.py rename to tensorflow/contrib/bigtable/python/kernel_tests/__init__.py index 582af2cf10a3d92dd8611b0f2826625e3acfb099..292d8f4e51abbbd89d68b47febd86b7297bb8ed2 100644 --- a/tensorflow/contrib/control_flow/__init__.py +++ b/tensorflow/contrib/bigtable/python/kernel_tests/__init__.py @@ -13,19 +13,8 @@ # limitations under the License. # ============================================================================== -"""New implementations of TF control flow ops. - -@@cond_v2 -""" +"""This module contains tests for the bigtable integration.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function - -# pylint: disable=unused-import -from tensorflow.contrib.control_flow.python.cond_v2 import cond_v2 -# pylint: enable=unused-import - -from tensorflow.python.util.all_util import remove_undocumented - -remove_undocumented(__name__) diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d33a66f2dfbecd0dc1082fd98973660ce9a93931 --- /dev/null +++ b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py @@ -0,0 +1,132 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Bigtable Ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import bigtable +from tensorflow.contrib.bigtable.ops import gen_bigtable_ops +from tensorflow.contrib.bigtable.ops import gen_bigtable_test_ops +from tensorflow.contrib.util import loader +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import test +from tensorflow.python.util import compat + +_bigtable_so = loader.load_op_library( + resource_loader.get_path_to_datafile("_bigtable_test.so")) + + +class BigtableOpsTest(test.TestCase): + COMMON_ROW_KEYS = ["r1", "r2", "r3"] + COMMON_VALUES = ["v1", "v2", "v3"] + + def setUp(self): + self._client = gen_bigtable_test_ops.bigtable_test_client() + table = gen_bigtable_ops.bigtable_table(self._client, "testtable") + self._table = bigtable.BigTable("testtable", None, table) + + def _makeSimpleDataset(self): + output_rows = dataset_ops.Dataset.from_tensor_slices(self.COMMON_ROW_KEYS) + output_values = dataset_ops.Dataset.from_tensor_slices(self.COMMON_VALUES) + return dataset_ops.Dataset.zip((output_rows, output_values)) + + def _writeCommonValues(self, sess): + output_ds = self._makeSimpleDataset() + write_op = self._table.write(output_ds, ["cf1"], ["c1"]) + sess.run(write_op) + + def runReadKeyTest(self, read_ds): + itr = read_ds.make_initializable_iterator() + n = itr.get_next() + expected = list(self.COMMON_ROW_KEYS) + expected.reverse() + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + for i in range(3): + output = sess.run(n) + want = expected.pop() + self.assertEqual( + compat.as_bytes(want), compat.as_bytes(output), + "Unequal at step %d: want: %s, got: %s" % (i, want, output)) + + def testReadPrefixKeys(self): + self.runReadKeyTest(self._table.keys_by_prefix_dataset("r")) + + def testReadRangeKeys(self): + self.runReadKeyTest(self._table.keys_by_range_dataset("r1", "r4")) + + def runScanTest(self, read_ds): + itr = read_ds.make_initializable_iterator() + n = itr.get_next() + expected_keys = list(self.COMMON_ROW_KEYS) + expected_keys.reverse() + expected_values = list(self.COMMON_VALUES) + expected_values.reverse() + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + for i in range(3): + output = sess.run(n) + want = expected_keys.pop() + self.assertEqual( + compat.as_bytes(want), compat.as_bytes(output[0]), + "Unequal keys at step %d: want: %s, got: %s" % (i, want, output[0])) + want = expected_values.pop() + self.assertEqual( + compat.as_bytes(want), compat.as_bytes(output[1]), + "Unequal values at step: %d: want: %s, got: %s" % (i, want, + output[1])) + + def testScanPrefixStringCol(self): + self.runScanTest(self._table.scan_prefix("r", cf1="c1")) + + def testScanPrefixListCol(self): + self.runScanTest(self._table.scan_prefix("r", cf1=["c1"])) + + def testScanRangeStringCol(self): + self.runScanTest(self._table.scan_range("r1", "r4", cf1="c1")) + + def testScanRangeListCol(self): + self.runScanTest(self._table.scan_range("r1", "r4", cf1=["c1"])) + + def testLookup(self): + ds = self._table.keys_by_prefix_dataset("r") + ds = ds.apply(self._table.lookup_columns(cf1="c1")) + itr = ds.make_initializable_iterator() + n = itr.get_next() + expected_keys = list(self.COMMON_ROW_KEYS) + expected_values = list(self.COMMON_VALUES) + expected_tuples = zip(expected_keys, expected_values) + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + for i, elem in enumerate(expected_tuples): + output = sess.run(n) + self.assertEqual( + compat.as_bytes(elem[0]), compat.as_bytes(output[0]), + "Unequal keys at step %d: want: %s, got: %s" % + (i, compat.as_bytes(elem[0]), compat.as_bytes(output[0]))) + self.assertEqual( + compat.as_bytes(elem[1]), compat.as_bytes(output[1]), + "Unequal values at step %d: want: %s, got: %s" % + (i, compat.as_bytes(elem[1]), compat.as_bytes(output[1]))) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/training/checkpointable/data_structures_base.py b/tensorflow/contrib/bigtable/python/ops/__init__.py similarity index 65% rename from tensorflow/python/training/checkpointable/data_structures_base.py rename to tensorflow/contrib/bigtable/python/ops/__init__.py index f1b2cf105b81490ea12e0a667f53fb02d45135c9..36d75b0d7068a650347a5e17f4727a5432d8752f 100644 --- a/tensorflow/python/training/checkpointable/data_structures_base.py +++ b/tensorflow/contrib/bigtable/python/ops/__init__.py @@ -1,5 +1,4 @@ -"""A trivial base class to avoid circular imports for isinstance checks.""" -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,15 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== + +"""This module contains the Python API for the Cloud Bigtable integration.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function - - -from tensorflow.python.training.checkpointable import base as checkpointable_lib - - -class CheckpointableDataStructureBase(checkpointable_lib.CheckpointableBase): - """Base class for data structures which contain checkpointable objects.""" - - pass diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py new file mode 100644 index 0000000000000000000000000000000000000000..a54e020ed770ed24f6ede1aac5ed4674a41b0e52 --- /dev/null +++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py @@ -0,0 +1,480 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Python API for TensorFlow's Bigtable integration. + +TensorFlow has support for reading from and writing to Cloud Bigtable. To use +the Bigtable TensorFlow integration, first create a BigtableClient (which +configures your connection to Cloud Bigtable), and then open a Table. The Table +object then allows you to create numerous @{tf.data.Dataset}s to read data, or +write a @{tf.data.Dataset} object to the underlying Bigtable Table. + +For background on Google Cloud Bigtable, see: https://cloud.google.com/bigtable. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from six import iteritems + +from tensorflow.contrib.bigtable.ops import gen_bigtable_ops +from tensorflow.contrib.util import loader +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.platform import resource_loader + +_bigtable_so = loader.load_op_library( + resource_loader.get_path_to_datafile("_bigtable.so")) + + +class BigtableClient(object): + """BigtableClient is the entrypoint for interacting with Cloud Bigtable in TF. + + BigtableClient encapsulates a connection to Cloud Bigtable, and exposes the + `table` method to open a Bigtable Table. + """ + + def __init__(self, project_id, instance_id): + """Creates a BigtableClient that can be used to open connections to tables. + + Args: + project_id: A string representing the GCP project id to connect to. + instance_id: A string representing the Bigtable instance to connect to. + """ + self._project_id = project_id + self._instance_id = instance_id + self._resource = gen_bigtable_ops.bigtable_client(project_id, instance_id) + + def table(self, name, snapshot=None): + """Opens a table and returns a `BigTable` object. + + Args: + name: A `tf.string` `tf.Tensor` name of the table to open. + snapshot: Either a `tf.string` `tf.Tensor` snapshot id, or `True` to + request the creation of a snapshot. (Note: currently unimplemented.) + + Returns: + A `BigTable` python object representing the operations available on the + table. + """ + # TODO(saeta): Implement snapshot functionality. + table = gen_bigtable_ops.bigtable_table(self._resource, name) + return BigTable(name, snapshot, table) + + +class BigTable(object): + """BigTable is the entrypoint for reading and writing data in Cloud Bigtable. + + This BigTable class is the python representation of the Cloud Bigtable table + within TensorFlow. Methods on this class allow data to be read from and + written to the Cloud Bigtable service in flexible and high performance + manners. + """ + + # TODO(saeta): Investigate implementing tf.contrib.lookup.LookupInterface. + # TODO(saeta): Consider variant tensors instead of resources (while supporting + # connection pooling). + + def __init__(self, name, snapshot, resource): + self._name = name + self._snapshot = snapshot + self._resource = resource + + def lookup_columns(self, *args, **kwargs): + """Retrieves the values of columns for a dataset of keys. + + Example usage: + ``` + table = bigtable_client.table("my_table") + key_dataset = table.get_keys_prefix("imagenet") + images = key_dataset.apply(table.lookup_columns(("cf1", "image"), + ("cf2", "label"), + ("cf2", "boundingbox"))) + training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128) + ``` + + Alternatively, you can use keyword arguments to specify the columns to + capture. Example (same as above, rewritten): + ``` + table = bigtable_client.table("my_table") + key_dataset = table.get_keys_prefix("imagenet") + images = key_dataset.apply(table.lookup_columns( + cf1="image", cf2=("label", "boundingbox"))) + training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128) + ``` + + Note: certain kwargs keys are reserved, and thus some column families cannot + be identified using the kwargs syntax. Instead, please use the args syntax. + This list includes: + - 'name' + This list can change at any time. + + Args: + *args: A list of tuples containing (column family, column name) pairs. + **kwargs: Column families and + + Returns: + A function that can be passed to `tf.data.Dataset.apply` to retrieve the + values of columns for the rows. + """ + table = self # Capture self + normalized = args + if normalized is None: + normalized = [] + if isinstance(normalized, tuple): + normalized = list(normalized) + for key, value in iteritems(kwargs): + if key == "name": + continue + if isinstance(value, str): + normalized.append((key, value)) + continue + for col in value: + normalized.append((key, col)) + + def _apply_fn(dataset): + # TODO(saeta): Verify dataset's types are correct! + return _BigtableLookupDataset(dataset, table, normalized) + + return _apply_fn + + def keys_by_range_dataset(self, start, end): + """Retrieves all row keys between start and end. + + Note: it does NOT retrieve the values of columns. + + Args: + start: The start row key. The row keys for rows after start (inclusive) + will be retrieved. + end: (Optional.) The end row key. Rows up to (but not including) end will + be retrieved. If end is None, all subsequent row keys will be retrieved. + + Returns: + A @{tf.data.Dataset} containing `tf.string` Tensors corresponding to all + of the row keys between `start` and `end`. + """ + # TODO(saeta): Make inclusive / exclusive configurable? + if end is None: + end = "" + return _BigtableRangeKeyDataset(self, start, end) + + def keys_by_prefix_dataset(self, prefix): + """Retrieves the row keys matching a given prefix. + + Args: + prefix: All row keys that begin with `prefix` in the table will be + retrieved. + + Returns: + A @{tf.data.Dataset}. containing `tf.string` Tensors corresponding to all + of the row keys matching that prefix. + """ + return _BigtablePrefixKeyDataset(self, prefix) + + def scan_prefix(self, prefix, probability=None, columns=None, **kwargs): + """Retrieves row (including values) from the Bigtable service. + + Rows with row-key prefixed by `prefix` will be retrieved. + + Specifying the columns to retrieve for each row is done by either using + kwargs or in the columns parameter. To retrieve values of the columns "c1", + and "c2" from the column family "cfa", and the value of the column "c3" + from column family "cfb", the following datasets (`ds1`, and `ds2`) are + equivalent: + + ``` + table = # ... + ds1 = table.scan_prefix("row_prefix", columns=[("cfa", "c1"), + ("cfa", "c2"), + ("cfb", "c3")]) + ds2 = table.scan_prefix("row_prefix", cfa=["c1", "c2"], cfb="c3") + ``` + + Note: only the latest value of a cell will be retrieved. + + Args: + prefix: The prefix all row keys muat match to be retrieved for prefix- + based scans. + probability: Probabilistically sample rows. + columns: The columns to read. Note: most commonly, they are expressed as + kwargs. Use the columns value if you are using column families that are + reserved. The value of columns and kwargs are merged. Columns is a list + of tuples of strings ("column_family", "column_qualifier"). + **kwargs: The column families and columns to read. Keys are treated as + column_families, and values can be either lists of strings, or strings + that are treated as the column qualifier (column name). + + Returns: + A @{tf.data.Dataset} returning the row keys and the cell contents. + + Raises: + ValueError: If the configured probability is unexpected. + """ + if probability is None: + probability = 1.0 + if isinstance(probability, float) and (probability <= 0.0 or + probability > 1.0): + raise ValueError("probability must be in the range (0, 1].") + + normalized = columns + if normalized is None: + normalized = [] + if isinstance(normalized, tuple): + normalized = list(normalized) + for key, value in iteritems(kwargs): + if key == "name": + continue + if isinstance(value, str): + normalized.append((key, value)) + continue + for col in value: + normalized.append((key, col)) + + return _BigtableScanDataset(self, prefix, "", "", normalized, probability) + + def scan_range(self, start, end, probability=None, columns=None, **kwargs): + """Retrieves rows (including values) from the Bigtable service. + + Rows with row-keys between `start` and `end` will be retrieved. + + Specifying the columns to retrieve for each row is done by either using + kwargs or in the columns parameter. To retrieve values of the columns "c1", + and "c2" from the column family "cfa", and the value of the column "c3" + from column family "cfb", the following datasets (`ds1`, and `ds2`) are + equivalent: + + ``` + table = # ... + ds1 = table.scan_range("row_start", "row_end", columns=[("cfa", "c1"), + ("cfa", "c2"), + ("cfb", "c3")]) + ds2 = table.scan_range("row_start", "row_end", cfa=["c1", "c2"], cfb="c3") + ``` + + Note: only the latest value of a cell will be retrieved. + + Args: + start: The start of the range when scanning by range. + end: (Optional.) The end of the range when scanning by range. + probability: Probabilistically sample rows. + columns: The columns to read. Note: most commonly, they are expressed as + kwargs. Use the columns value if you are using column families that are + reserved. The value of columns and kwargs are merged. Columns is a list + of tuples of strings ("column_family", "column_qualifier"). + **kwargs: The column families and columns to read. Keys are treated as + column_families, and values can be either lists of strings, or strings + that are treated as the column qualifier (column name). + + Returns: + A @{tf.data.Dataset} returning the row keys and the cell contents. + + Raises: + ValueError: If the configured probability is unexpected. + """ + if probability is None: + probability = 1.0 + if isinstance(probability, float) and (probability <= 0.0 or + probability > 1.0): + raise ValueError("probability must be in the range (0, 1].") + + normalized = columns + if normalized is None: + normalized = [] + if isinstance(normalized, tuple): + normalized = list(normalized) + for key, value in iteritems(kwargs): + if key == "name": + continue + if isinstance(value, str): + normalized.append((key, value)) + continue + for col in value: + normalized.append((key, col)) + + return _BigtableScanDataset(self, "", start, end, normalized, probability) + + def write(self, dataset, column_families, columns, timestamp=None): + """Writes a dataset to the table. + + Args: + dataset: A @{tf.data.Dataset} to be written to this table. It must produce + a list of number-of-columns+1 elements, all of which must be strings. + The first value will be used as the row key, and subsequent values will + be used as cell values for the corresponding columns from the + corresponding column_families and columns entries. + column_families: A @{tf.Tensor} of `tf.string`s corresponding to the + column names to store the dataset's elements into. + columns: A `tf.Tensor` of `tf.string`s corresponding to the column names + to store the dataset's elements into. + timestamp: (Optional.) An int64 timestamp to write all the values at. + Leave as None to use server-provided timestamps. + + Returns: + A @{tf.Operation} that can be run to perform the write. + + Raises: + ValueError: If there are unexpected or incompatible types, or if the + number of columns and column_families does not match the output of + `dataset`. + """ + if timestamp is None: + timestamp = -1 # Bigtable server provided timestamp. + for tensor_type in nest.flatten(dataset.output_types): + if tensor_type != dtypes.string: + raise ValueError("Not all elements of the dataset were `tf.string`") + for shape in nest.flatten(dataset.output_shapes): + if not shape.is_compatible_with(tensor_shape.scalar()): + raise ValueError("Not all elements of the dataset were scalars") + if len(column_families) != len(columns): + raise ValueError("len(column_families) != len(columns)") + if len(nest.flatten(dataset.output_types)) != len(columns) + 1: + raise ValueError("A column name must be specified for every component of " + "the dataset elements. (e.g.: len(columns) != " + "len(dataset.output_types))") + return gen_bigtable_ops.dataset_to_bigtable( + self._resource, + dataset._as_variant_tensor(), # pylint: disable=protected-access + column_families, + columns, + timestamp) + + +class _BigtableKeyDataset(dataset_ops.Dataset): + """_BigtableKeyDataset is an abstract class representing the keys of a table. + """ + + def __init__(self, table): + """Constructs a _BigtableKeyDataset. + + Args: + table: a Bigtable class. + """ + super(_BigtableKeyDataset, self).__init__() + self._table = table + + @property + def output_classes(self): + return ops.Tensor + + @property + def output_shapes(self): + return tensor_shape.TensorShape([]) + + @property + def output_types(self): + return dtypes.string + + +class _BigtablePrefixKeyDataset(_BigtableKeyDataset): + """_BigtablePrefixKeyDataset represents looking up keys by prefix. + """ + + def __init__(self, table, prefix): + super(_BigtablePrefixKeyDataset, self).__init__(table) + self._prefix = prefix + + def _as_variant_tensor(self): + return gen_bigtable_ops.bigtable_prefix_key_dataset( + table=self._table._resource, # pylint: disable=protected-access + prefix=self._prefix) + + +class _BigtableRangeKeyDataset(_BigtableKeyDataset): + """_BigtableRangeKeyDataset represents looking up keys by range. + """ + + def __init__(self, table, start, end): + super(_BigtableRangeKeyDataset, self).__init__(table) + self._start = start + self._end = end + + def _as_variant_tensor(self): + return gen_bigtable_ops.bigtable_range_key_dataset( + table=self._table._resource, # pylint: disable=protected-access + start_key=self._start, + end_key=self._end) + + +class _BigtableLookupDataset(dataset_ops.Dataset): + """_BigtableLookupDataset represents a dataset that retrieves values for keys. + """ + + def __init__(self, dataset, table, normalized): + self._num_outputs = len(normalized) + 1 # 1 for row key + self._dataset = dataset + self._table = table + self._normalized = normalized + self._column_families = [i[0] for i in normalized] + self._columns = [i[1] for i in normalized] + + @property + def output_classes(self): + return tuple([ops.Tensor] * self._num_outputs) + + @property + def output_shapes(self): + return tuple([tensor_shape.TensorShape([])] * self._num_outputs) + + @property + def output_types(self): + return tuple([dtypes.string] * self._num_outputs) + + def _as_variant_tensor(self): + # pylint: disable=protected-access + return gen_bigtable_ops.bigtable_lookup_dataset( + keys_dataset=self._dataset._as_variant_tensor(), + table=self._table._resource, + column_families=self._column_families, + columns=self._columns) + + +class _BigtableScanDataset(dataset_ops.Dataset): + """_BigtableScanDataset represents a dataset that retrieves keys and values. + """ + + def __init__(self, table, prefix, start, end, normalized, probability): + self._table = table + self._prefix = prefix + self._start = start + self._end = end + self._column_families = [i[0] for i in normalized] + self._columns = [i[1] for i in normalized] + self._probability = probability + self._num_outputs = len(normalized) + 1 # 1 for row key + + @property + def output_classes(self): + return tuple([ops.Tensor] * self._num_outputs) + + @property + def output_shapes(self): + return tuple([tensor_shape.TensorShape([])] * self._num_outputs) + + @property + def output_types(self): + return tuple([dtypes.string] * self._num_outputs) + + def _as_variant_tensor(self): + return gen_bigtable_ops.bigtable_scan_dataset( + table=self._table._resource, # pylint: disable=protected-access + prefix=self._prefix, + start_key=self._start, + end_key=self._end, + column_families=self._column_families, + columns=self._columns, + probability=self._probability) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index 8cff1a3bb1d11aff6a264636291a7149b40de516..ef0e80cd0997bc0e95cd0d150e87db144a2dde44 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -15,8 +15,9 @@ py_library( srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ - "custom_export_strategy", + ":custom_export_strategy", ":custom_loss_head", + ":distillation_loss", ":estimator", ":model", ":trainer_hooks", @@ -144,6 +145,7 @@ py_library( srcs = ["dnn_tree_combined_estimator.py"], srcs_version = "PY2AND3", deps = [ + ":distillation_loss", ":estimator_utils", ":trainer_hooks", "//tensorflow/contrib/boosted_trees:gbdt_batch", @@ -156,6 +158,17 @@ py_library( ], ) +py_library( + name = "distillation_loss", + srcs = ["distillation_loss.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/learn", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", + ], +) + py_test( name = "dnn_tree_combined_estimator_test", size = "medium", diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/distillation_loss.py b/tensorflow/contrib/boosted_trees/estimator_batch/distillation_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..9aacc5534329d1302b25dcfab678f9adb8f773f6 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/estimator_batch/distillation_loss.py @@ -0,0 +1,75 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utill functions for distillation loss. + +The distillation loss_fn will be called with the following: + +Args: + dnn_logits: Tensor of logits from the dnn, treated as the "target". This will + be the output of a call to tf.stop_gradient(). + tree_logits: Tensor of logits from the tree, treated as the "predictions". + example_weights: Tensor of example weights, or a single scalar. + +Returns: + A scalar indicating the reduced loss for that batch of examples. + +Note: we calls the loss_fn defined in contrib head, which is computing two +losses, first one for training and second one for reporting. We only take the +first one here. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.learn.python.learn.estimators import head as head_lib +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn + + +def _logits_to_label_for_tree(logits, n_classes): + if n_classes == 2: + return math_ops.sigmoid(logits) + else: + return nn.softmax(logits) + + +def create_dnn_to_tree_squared_loss_fn(n_classes): + """Returns a squared loss function for dnn to tree distillation.""" + + def _dnn_to_tree_squared_loss(dnn_logits, tree_logits, example_weights): + return head_lib._mean_squared_loss( # pylint: disable=protected-access + labels=_logits_to_label_for_tree(dnn_logits, n_classes), + logits=_logits_to_label_for_tree(tree_logits, n_classes), + weights=example_weights)[0] + + return _dnn_to_tree_squared_loss + + +def create_dnn_to_tree_cross_entropy_loss_fn(n_classes): + """Returns a cross entropy loss function for dnn to tree distillation.""" + + def _dnn_to_tree_cross_entropy_loss(dnn_logits, tree_logits, example_weights): + if n_classes == 2: + return head_lib._log_loss_with_two_classes( # pylint: disable=protected-access + labels=_logits_to_label_for_tree(dnn_logits, n_classes), + logits=tree_logits, + weights=example_weights)[0] + else: + return head_lib._softmax_cross_entropy_loss( # pylint: disable=protected-access + labels=_logits_to_label_for_tree(dnn_logits, n_classes), + logits=tree_logits, + weights=example_weights)[0] + + return _dnn_to_tree_cross_entropy_loss diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py index 911d87fa10570382ee5f03edfc1bfd1d116c8360..7eb429b636a5193a124dd9b0c020dae6cac910cb 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -24,7 +24,9 @@ from __future__ import division from __future__ import print_function import six + from tensorflow.contrib import layers +from tensorflow.contrib.boosted_trees.estimator_batch import distillation_loss from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks from tensorflow.contrib.boosted_trees.python.ops import model_ops @@ -35,11 +37,13 @@ from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.python.feature_column import feature_column as feature_column_lib from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import nn from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import training_util @@ -77,6 +81,7 @@ def _dnn_tree_combined_model_fn(features, predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, + dnn_to_tree_distillation_param=None, use_core_versions=False): """DNN and GBDT combined model_fn. @@ -117,6 +122,13 @@ def _dnn_tree_combined_model_fn(features, set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the + float defines the weight of the distillation loss, and the loss_fn, for + computing distillation loss, takes dnn_logits, tree_logits and weight + tensor. If the entire tuple is None, no distillation will be applied. If + only the loss_fn is None, we will take the sigmoid/softmax cross entropy + loss be default. When distillation is applied, `predict_with_tree_only` + will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. @@ -132,6 +144,12 @@ def _dnn_tree_combined_model_fn(features, if not dnn_feature_columns: raise ValueError("dnn_feature_columns must be specified") + if dnn_to_tree_distillation_param: + if not predict_with_tree_only: + logging.warning("update predict_with_tree_only to True since distillation" + "is specified.") + predict_with_tree_only = True + # Build DNN Logits. dnn_parent_scope = "dnn" dnn_partitioner = dnn_input_layer_partitioner or ( @@ -225,6 +243,25 @@ def _dnn_tree_combined_model_fn(features, def _tree_train_op_fn(loss): """Returns the op to optimize the loss.""" + if dnn_to_tree_distillation_param: + loss_weight, loss_fn = dnn_to_tree_distillation_param + weight_tensor = head_lib._weight_tensor( # pylint: disable=protected-access + features, head.weight_column_name) + dnn_logits_fixed = array_ops.stop_gradient(dnn_logits) + + if loss_fn is None: + # we create the loss_fn similar to the head loss_fn for + # multi_class_head used previously as the default one. + n_classes = 2 if head.logits_dimension == 1 else head.logits_dimension + loss_fn = distillation_loss.create_dnn_to_tree_cross_entropy_loss_fn( + n_classes) + + dnn_to_tree_distillation_loss = loss_weight * loss_fn( + dnn_logits_fixed, tree_logits, weight_tensor) + summary.scalar("dnn_to_tree_distillation_loss", + dnn_to_tree_distillation_loss) + loss += dnn_to_tree_distillation_loss + update_op = gbdt_model.train(loss, predictions_dict, labels) with ops.control_dependencies( [update_op]), (ops.colocate_with(global_step)): @@ -232,7 +269,7 @@ def _dnn_tree_combined_model_fn(features, return update_op if predict_with_tree_only: - if mode == model_fn.ModeKeys.TRAIN or mode == model_fn.ModeKeys.PREDICT: + if mode == model_fn.ModeKeys.TRAIN or mode == model_fn.ModeKeys.INFER: tree_train_logits = tree_logits else: tree_train_logits = control_flow_ops.cond( @@ -331,6 +368,7 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, + dnn_to_tree_distillation_param=None, use_core_versions=False): """Initializes a DNNBoostedTreeCombinedClassifier instance. @@ -378,6 +416,13 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the + float defines the weight of the distillation loss, and the loss_fn, for + computing distillation loss, takes dnn_logits, tree_logits and weight + tensor. If the entire tuple is None, no distillation will be applied. If + only the loss_fn is None, we will take the sigmoid/softmax cross entropy + loss be default. When distillation is applied, `predict_with_tree_only` + will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. """ @@ -409,6 +454,7 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): predict_with_tree_only=predict_with_tree_only, tree_feature_columns=tree_feature_columns, tree_center_bias=tree_center_bias, + dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, use_core_versions=use_core_versions) super(DNNBoostedTreeCombinedClassifier, self).__init__( @@ -442,6 +488,7 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, + dnn_to_tree_distillation_param=None, use_core_versions=False): """Initializes a DNNBoostedTreeCombinedRegressor instance. @@ -489,6 +536,13 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the + float defines the weight of the distillation loss, and the loss_fn, for + computing distillation loss, takes dnn_logits, tree_logits and weight + tensor. If the entire tuple is None, no distillation will be applied. If + only the loss_fn is None, we will take the sigmoid/softmax cross entropy + loss be default. When distillation is applied, `predict_with_tree_only` + will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. """ @@ -525,6 +579,7 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): predict_with_tree_only=predict_with_tree_only, tree_feature_columns=tree_feature_columns, tree_center_bias=tree_center_bias, + dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, use_core_versions=use_core_versions) super(DNNBoostedTreeCombinedRegressor, self).__init__( @@ -559,6 +614,7 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, + dnn_to_tree_distillation_param=None, use_core_versions=False): """Initializes a DNNBoostedTreeCombinedEstimator instance. @@ -601,6 +657,13 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the + float defines the weight of the distillation loss, and the loss_fn, for + computing distillation loss, takes dnn_logits, tree_logits and weight + tensor. If the entire tuple is None, no distillation will be applied. If + only the loss_fn is None, we will take the sigmoid/softmax cross entropy + loss be default. When distillation is applied, `predict_with_tree_only` + will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. """ @@ -626,6 +689,7 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): predict_with_tree_only=predict_with_tree_only, tree_feature_columns=tree_feature_columns, tree_center_bias=tree_center_bias, + dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, use_core_versions=use_core_versions) super(DNNBoostedTreeCombinedEstimator, self).__init__( diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py index f495edc62f0909880c170ccb4cf5d11e3f20f55c..9b7acfa664b0398216b5a7fb904960d8363929d6 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py @@ -131,6 +131,30 @@ class DNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase): classifier.fit(input_fn=_train_input_fn, steps=15) classifier.evaluate(input_fn=_eval_input_fn, steps=1) + def testFitAndEvaluateWithDistillation(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.DNNBoostedTreeCombinedClassifier( + dnn_hidden_units=[1], + dnn_feature_columns=[feature_column.real_valued_column("x")], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + n_classes=2, + model_dir=model_dir, + config=config, + dnn_steps_to_train=10, + dnn_input_layer_to_tree=False, + tree_feature_columns=[feature_column.real_valued_column("x")], + dnn_to_tree_distillation_param=(1, None)) + + classifier.fit(input_fn=_train_input_fn, steps=15) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py index 56ff00b39062d57c813633c98c765e077dd4c262..1b7f59ea4218355a13f1df7264352bd68503bd19 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py @@ -37,6 +37,7 @@ class BaseSplitHandler(object): gradient_shape, hessian_shape, multiclass_strategy, + loss_uses_sum_reduction=False, name=None): """Constructor for BaseSplitHandler. @@ -51,6 +52,8 @@ class BaseSplitHandler(object): gradient_shape: A TensorShape, containing shape of gradients. hessian_shape: A TensorShape, containing shape of hessians. multiclass_strategy: Strategy describing how to treat multiclass problems. + loss_uses_sum_reduction: A scalar boolean tensor that specifies whether + SUM or MEAN reduction was used for the loss. name: An optional handler name. """ self._l1_regularization = l1_regularization @@ -62,6 +65,7 @@ class BaseSplitHandler(object): self._multiclass_strategy = multiclass_strategy self._hessian_shape = hessian_shape self._gradient_shape = gradient_shape + self._loss_uses_sum_reduction = loss_uses_sum_reduction def scheduled_reads(self): """Returns the list of `ScheduledOp`s required for update_stats.""" diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py index 9f78ab20242800fd8af7ad049d5970fbe26ec0ea..bf686237ff696dadad9713d26bf784d7442b80d0 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py @@ -23,6 +23,7 @@ from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -44,6 +45,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): hessian_shape, multiclass_strategy, init_stamp_token=0, + loss_uses_sum_reduction=False, name=None): """Initialize the internal state for this split handler. @@ -62,6 +64,8 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): multiclass_strategy: Strategy describing how to treat multiclass problems. init_stamp_token: A tensor containing an scalar for initial stamp of the stamped objects. + loss_uses_sum_reduction: A scalar boolean tensor that specifies whether + SUM or MEAN reduction was used for the loss. name: An optional handler name. """ super(EqualitySplitHandler, self).__init__( @@ -73,6 +77,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): gradient_shape=gradient_shape, hessian_shape=hessian_shape, multiclass_strategy=multiclass_strategy, + loss_uses_sum_reduction=loss_uses_sum_reduction, name=name) self._stats_accumulator = stats_accumulator_ops.StatsAccumulator( init_stamp_token, @@ -173,6 +178,11 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): # pair. num_minibatches, partition_ids, feature_ids, gradients, hessians = ( self._stats_accumulator.flush(stamp_token, next_stamp_token)) + # For sum_reduction, we don't need to divide by number of minibatches. + + num_minibatches = control_flow_ops.cond( + ops.convert_to_tensor(self._loss_uses_sum_reduction), + lambda: math_ops.to_int64(1), lambda: num_minibatches) partition_ids, gains, split_infos = ( split_handler_ops.build_categorical_equality_splits( num_minibatches=num_minibatches, @@ -187,7 +197,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): tree_complexity_regularization=self._tree_complexity_regularization, min_node_weight=self._min_node_weight, bias_feature_id=_BIAS_FEATURE_ID, - multiclass_strategy=self._multiclass_strategy,)) + multiclass_strategy=self._multiclass_strategy)) # There are no warm-up rounds needed in the equality column handler. So we # always return ready. are_splits_ready = constant_op.constant(True) diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py index 0b65eba2a76273a81f1464ed7639f0c0760e0050..ef253e7cec4e8a96b360ced32b59398c2e2c9680 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py @@ -90,7 +90,17 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): empty_hessians, example_weights, is_active=array_ops.constant([True, True])) - with ops.control_dependencies([update_1]): + update_2 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + + with ops.control_dependencies([update_1, update_2]): are_splits_ready, partitions, gains, splits = ( split_handler.make_splits(0, 1, class_id)) are_splits_ready, partitions, gains, splits = (sess.run( @@ -159,6 +169,129 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(1, split_node.feature_id) + def testGenerateFeatureSplitCandidatesSumReduction(self): + with self.test_session() as sess: + # The data looks like the following: + # Example | Gradients | Partition | Feature ID | + # i0 | (0.2, 0.12) | 0 | 1,2 | + # i1 | (-0.5, 0.07) | 0 | | + # i2 | (1.2, 0.2) | 0 | 2 | + # i3 | (4.0, 0.13) | 1 | 1 | + gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) + hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) + partition_ids = [0, 0, 0, 1] + indices = [[0, 0], [0, 1], [2, 0], [3, 0]] + values = array_ops.constant([1, 2, 2, 1], dtype=dtypes.int64) + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + class_id = -1 + + split_handler = categorical_split_handler.EqualitySplitHandler( + l1_regularization=0.1, + l2_regularization=1, + tree_complexity_regularization=0, + min_node_weight=0, + sparse_int_column=sparse_tensor.SparseTensor(indices, values, [4, 1]), + feature_column_group_id=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + init_stamp_token=0, + loss_uses_sum_reduction=True) + resources.initialize_resources(resources.shared_resources()).run() + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([4, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + update_2 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1, update_2]): + are_splits_ready, partitions, gains, splits = ( + split_handler.make_splits(0, 1, class_id)) + are_splits_ready, partitions, gains, splits = ( + sess.run([are_splits_ready, partitions, gains, splits])) + self.assertTrue(are_splits_ready) + self.assertAllEqual([0, 1], partitions) + + # Check the split on partition 0. + # -(0.4 + 2.4 - 0.1) / (0.24 + 0.4 + 1) + expected_left_weight = -1.6463414634146338 + + # (0.4 + 2.4 - 0.1) ** 2 / (0.24 + 0.4 + 1) + expected_left_gain = 4.445121951219511 + + # -(-1 + 0.1) / (0.14 + 1) + expected_right_weight = 0.789473684211 + + # (-1 + 0.1) ** 2 / (0.14 + 1) + expected_right_gain = 0.710526315789 + + # (0.4 + -1 + 2.4 - 0.1) ** 2 / (0.24 + 0.14 + 0.4 + 1) + expected_bias_gain = 1.6235955056179772 + + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.categorical_id_binary_split + + self.assertEqual(0, split_node.feature_column) + + self.assertEqual(2, split_node.feature_id) + + self.assertAllClose( + expected_left_gain + expected_right_gain - expected_bias_gain, gains[0], + 0.00001) + + self.assertAllClose([expected_left_weight], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight], right_child.value, 0.00001) + + # Check the split on partition 1. + # (-8 + 0.1) / (0.26 + 1) + expected_left_weight = -6.26984126984 + # (-8 + 0.1) ** 2 / (0.26 + 1) + expected_left_gain = 49.5317460317 + expected_right_weight = 0 + expected_right_gain = 0 + # (-8 + 0.1) ** 2 / (0.26 + 1) + expected_bias_gain = 49.5317460317 + + # Verify candidate for partition 1, there's only one active feature here + # so zero gain is expected. + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[1]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.categorical_id_binary_split + self.assertAllClose(0.0, gains[1], 0.00001) + + self.assertAllClose([expected_left_weight], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight], right_child.value, 0.00001) + + self.assertEqual(0, split_node.feature_column) + + self.assertEqual(1, split_node.feature_id) + def testGenerateFeatureSplitCandidatesMulticlass(self): with self.test_session() as sess: # Batch size is 4, 2 gradients per each instance. diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index 409a2d8f46c331c13aec10542c4967d50575e94a..df0bec1fe363e07bbff6b059e86076239bd605e9 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -99,6 +99,7 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler): hessian_shape, multiclass_strategy, init_stamp_token=0, + loss_uses_sum_reduction=False, name=None): """Initialize the internal state for this split handler. @@ -117,6 +118,8 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler): multiclass_strategy: Strategy describing how to treat multiclass problems. init_stamp_token: A tensor containing an scalar for initial stamp of the stamped objects. + loss_uses_sum_reduction: A scalar boolean tensor that specifies whether + SUM or MEAN reduction was used for the loss. name: An optional handler name. """ super(InequalitySplitHandler, self).__init__( @@ -128,7 +131,8 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler): feature_column_group_id=feature_column_group_id, gradient_shape=gradient_shape, hessian_shape=hessian_shape, - multiclass_strategy=multiclass_strategy) + multiclass_strategy=multiclass_strategy, + loss_uses_sum_reduction=loss_uses_sum_reduction) self._stats_accumulator = stats_accumulator_ops.StatsAccumulator( init_stamp_token, gradient_shape, @@ -160,6 +164,7 @@ class DenseSplitHandler(InequalitySplitHandler): hessian_shape, multiclass_strategy, init_stamp_token=0, + loss_uses_sum_reduction=False, name=None): """Initialize the internal state for this split handler. @@ -179,6 +184,8 @@ class DenseSplitHandler(InequalitySplitHandler): multiclass_strategy: Strategy describing how to treat multiclass problems. init_stamp_token: A tensor containing an scalar for initial stamp of the stamped objects. + loss_uses_sum_reduction: A scalar boolean tensor that specifies whether + SUM or MEAN reduction was used for the loss. name: An optional handler name. """ super(DenseSplitHandler, self).__init__( @@ -193,7 +200,8 @@ class DenseSplitHandler(InequalitySplitHandler): name=name, gradient_shape=gradient_shape, hessian_shape=hessian_shape, - multiclass_strategy=multiclass_strategy) + multiclass_strategy=multiclass_strategy, + loss_uses_sum_reduction=loss_uses_sum_reduction) self._dense_float_column = dense_float_column # Register dense_make_stats_update function as an Op to the graph. g = ops.get_default_graph() @@ -255,15 +263,15 @@ class DenseSplitHandler(InequalitySplitHandler): next_stamp_token, self._multiclass_strategy, class_id, self._feature_column_group_id, self._l1_regularization, self._l2_regularization, self._tree_complexity_regularization, - self._min_node_weight)) + self._min_node_weight, self._loss_uses_sum_reduction)) return are_splits_ready, partition_ids, gains, split_infos -def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle, - stamp_token, next_stamp_token, multiclass_strategy, - class_id, feature_column_id, l1_regularization, - l2_regularization, tree_complexity_regularization, - min_node_weight, is_multi_dimentional): +def _make_dense_split( + quantile_accumulator_handle, stats_accumulator_handle, stamp_token, + next_stamp_token, multiclass_strategy, class_id, feature_column_id, + l1_regularization, l2_regularization, tree_complexity_regularization, + min_node_weight, is_multi_dimentional, loss_uses_sum_reduction): """Function that builds splits for a dense feature column.""" # Get the bucket boundaries are_splits_ready, buckets = ( @@ -291,7 +299,10 @@ def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle, num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( gen_stats_accumulator_ops.stats_accumulator_scalar_flush( stats_accumulator_handle, stamp_token, next_stamp_token)) - + # For sum_reduction, we don't need to divide by number of minibatches. + num_minibatches = control_flow_ops.cond(loss_uses_sum_reduction, + lambda: math_ops.to_int64(1), + lambda: num_minibatches) # Put quantile and stats accumulator flushing in the dependency path. with ops.control_dependencies([flush_quantiles, partition_ids]): are_splits_ready = array_ops.identity(are_splits_ready) @@ -329,6 +340,7 @@ class SparseSplitHandler(InequalitySplitHandler): hessian_shape, multiclass_strategy, init_stamp_token=0, + loss_uses_sum_reduction=False, name=None): """Initialize the internal state for this split handler. @@ -348,6 +360,8 @@ class SparseSplitHandler(InequalitySplitHandler): multiclass_strategy: Strategy describing how to treat multiclass problems. init_stamp_token: A tensor containing an scalar for initial stamp of the stamped objects. + loss_uses_sum_reduction: A scalar boolean tensor that specifies whether + SUM or MEAN reduction was used for the loss. name: An optional handler name. """ super(SparseSplitHandler, self).__init__( @@ -362,6 +376,7 @@ class SparseSplitHandler(InequalitySplitHandler): hessian_shape=hessian_shape, multiclass_strategy=multiclass_strategy, init_stamp_token=init_stamp_token, + loss_uses_sum_reduction=loss_uses_sum_reduction, name=name) self._sparse_float_column = sparse_float_column @@ -424,15 +439,15 @@ class SparseSplitHandler(InequalitySplitHandler): next_stamp_token, self._multiclass_strategy, class_id, self._feature_column_group_id, self._l1_regularization, self._l2_regularization, self._tree_complexity_regularization, - self._min_node_weight)) + self._min_node_weight, self._loss_uses_sum_reduction)) return are_splits_ready, partition_ids, gains, split_infos -def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle, - stamp_token, next_stamp_token, multiclass_strategy, - class_id, feature_column_id, l1_regularization, - l2_regularization, tree_complexity_regularization, - min_node_weight, is_multi_dimentional): +def _make_sparse_split( + quantile_accumulator_handle, stats_accumulator_handle, stamp_token, + next_stamp_token, multiclass_strategy, class_id, feature_column_id, + l1_regularization, l2_regularization, tree_complexity_regularization, + min_node_weight, is_multi_dimentional, loss_uses_sum_reduction): """Function that builds splits for a sparse feature column.""" # Get the bucket boundaries are_splits_ready, buckets = ( @@ -460,7 +475,9 @@ def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle, num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( gen_stats_accumulator_ops.stats_accumulator_scalar_flush( stats_accumulator_handle, stamp_token, next_stamp_token)) - + num_minibatches = control_flow_ops.cond(loss_uses_sum_reduction, + lambda: math_ops.to_int64(1), + lambda: num_minibatches) # Put quantile and stats accumulator flushing in the dependency path. with ops.control_dependencies([flush_quantiles, partition_ids]): are_splits_ready = array_ops.identity(are_splits_ready) @@ -498,17 +515,18 @@ def _specialize_make_split(func, is_multi_dimentional): dtypes.float32, dtypes.float32, dtypes.float32, + dtypes.bool, noinline=True) def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token, next_stamp_token, multiclass_strategy, class_id, feature_column_id, l1_regularization, l2_regularization, tree_complexity_regularization, - min_node_weight): + min_node_weight, loss_uses_sum_reduction): """Function that builds splits for a sparse feature column.""" - return func( - quantile_accumulator_handle, stats_accumulator_handle, stamp_token, - next_stamp_token, multiclass_strategy, class_id, feature_column_id, - l1_regularization, l2_regularization, tree_complexity_regularization, - min_node_weight, is_multi_dimentional) + return func(quantile_accumulator_handle, stats_accumulator_handle, + stamp_token, next_stamp_token, multiclass_strategy, class_id, + feature_column_id, l1_regularization, l2_regularization, + tree_complexity_regularization, min_node_weight, + is_multi_dimentional, loss_uses_sum_reduction) return f diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index 2f2c2302113bf59d6a065d5005c934dc76c2148d..d59732cf92eb85e88732ac5a17dccf475ae5342f 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -182,6 +182,144 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.threshold, 0.00001) + def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self): + with self.test_session() as sess: + # The data looks like the following: + # Example | Gradients | Partition | Dense Quantile | + # i0 | (0.2, 0.12) | 0 | 1 | + # i1 | (-0.5, 0.07) | 0 | 1 | + # i2 | (1.2, 0.2) | 0 | 0 | + # i3 | (4.0, 0.13) | 1 | 1 | + dense_column = array_ops.constant([0.52, 0.52, 0.3, 0.52]) + gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) + hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) + partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) + class_id = -1 + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + split_handler = ordinal_split_handler.DenseSplitHandler( + l1_regularization=0.2, + l2_regularization=2., + tree_complexity_regularization=0., + min_node_weight=0., + epsilon=0.001, + num_quantiles=10, + feature_column_group_id=0, + dense_float_column=dense_column, + init_stamp_token=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + loss_uses_sum_reduction=True) + resources.initialize_resources(resources.shared_resources()).run() + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([4, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1]): + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + + with ops.control_dependencies([are_splits_ready]): + update_2 = split_handler.update_stats_sync( + 1, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + update_3 = split_handler.update_stats_sync( + 1, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_2, update_3]): + are_splits_ready2, partitions, gains, splits = ( + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) + are_splits_ready, are_splits_ready2, partitions, gains, splits = ( + sess.run([ + are_splits_ready, are_splits_ready2, partitions, gains, splits + ])) + + # During the first iteration, inequality split handlers are not going to + # have any splits. Make sure that we return not_ready in that case. + self.assertFalse(are_splits_ready) + self.assertTrue(are_splits_ready2) + + self.assertAllEqual([0, 1], partitions) + + # Check the split on partition 0. + # -(2.4 - 0.2) / (0.4 + 2) + expected_left_weight = -0.91666 + + # expected_left_weight * -(2.4 - 0.2) + expected_left_gain = 2.016666666666666 + + # -(-1 + 0.4 + 0.2) / (0.38 + 2) + expected_right_weight = 0.1680672 + + # expected_right_weight * -(-1 + 0.4 + 0.2) + expected_right_gain = 0.0672268907563025 + + # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1) + expected_bias_gain = 0.9208633093525178 + + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.dense_float_binary_split + self.assertAllClose( + expected_left_gain + expected_right_gain - expected_bias_gain, gains[0], + 0.00001) + + self.assertAllClose([expected_left_weight], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight], right_child.value, 0.00001) + + self.assertEqual(0, split_node.feature_column) + + self.assertAllClose(0.3, split_node.threshold, 0.00001) + + # Check the split on partition 1. + # (-8 + 0.2) / (0.26 + 2) + expected_left_weight = -3.4513274336283186 + expected_right_weight = 0 + + # Verify candidate for partition 1, there's only one active bucket here + # so zero gain is expected. + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[1]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.dense_float_binary_split + self.assertAllClose(0.0, gains[1], 0.00001) + + self.assertAllClose([expected_left_weight], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight], right_child.value, 0.00001) + + self.assertEqual(0, split_node.feature_column) + + self.assertAllClose(0.52, split_node.threshold, 0.00001) + def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self): with self.test_session() as sess: dense_column = array_ops.constant([0.52, 0.52, 0.3, 0.52]) @@ -798,6 +936,139 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.split.threshold) + def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self): + with self.test_session() as sess: + # The data looks like the following: + # Example | Gradients | Partition | Sparse Quantile | + # i0 | (0.2, 0.12) | 0 | 1 | + # i1 | (-0.5, 0.07) | 0 | N/A | + # i2 | (1.2, 0.2) | 0 | 0 | + # i3 | (4.0, 0.13) | 1 | 1 | + gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) + hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) + example_partitions = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) + indices = array_ops.constant([[0, 0], [2, 0], [3, 0]], dtype=dtypes.int64) + values = array_ops.constant([0.52, 0.3, 0.52]) + sparse_column = sparse_tensor.SparseTensor(indices, values, [4, 1]) + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + class_id = -1 + + split_handler = ordinal_split_handler.SparseSplitHandler( + l1_regularization=0.0, + l2_regularization=4.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, + epsilon=0.01, + num_quantiles=2, + feature_column_group_id=0, + sparse_float_column=sparse_column, + init_stamp_token=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + loss_uses_sum_reduction=True) + resources.initialize_resources(resources.shared_resources()).run() + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([4, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + example_partitions, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1]): + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + with ops.control_dependencies([are_splits_ready]): + update_2 = split_handler.update_stats_sync( + 1, + example_partitions, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + update_3 = split_handler.update_stats_sync( + 1, + example_partitions, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_2, update_3]): + are_splits_ready2, partitions, gains, splits = ( + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) + are_splits_ready, are_splits_ready2, partitions, gains, splits = ( + sess.run([ + are_splits_ready, are_splits_ready2, partitions, gains, splits + ])) + + # During the first iteration, inequality split handlers are not going to + # have any splits. Make sure that we return not_ready in that case. + self.assertFalse(are_splits_ready) + self.assertTrue(are_splits_ready2) + + self.assertAllEqual([0, 1], partitions) + # Check the split on partition 0. + # -(0.4 + 2.4) / (0.24 + 0.4 + 4) + expected_left_weight = -0.603448275862069 + # (0.4 + 2.4) ** 2 / (0.24 + 0.4 + 4) + expected_left_gain = 1.689655172413793 + # 1 / (0.14 + 4) + expected_right_weight = 0.24154589371980678 + # 1 ** 2 / (0.14 + 4) + expected_right_gain = 0.24154589371980678 + # (0.4 + 2.4 - 1) ** 2 / (0.24 + 0.4 + 0.14 + 4) + expected_bias_gain = 0.6778242677824265 + + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.sparse_float_binary_split_default_right + self.assertAllClose( + expected_left_gain + expected_right_gain - expected_bias_gain, gains[0]) + + self.assertAllClose([expected_left_weight], left_child.value) + + self.assertAllClose([expected_right_weight], right_child.value) + + self.assertEqual(0, split_node.split.feature_column) + + self.assertAllClose(0.52, split_node.split.threshold) + + # Check the split on partition 1. + expected_left_weight = -1.8779342723004695 + expected_right_weight = 0 + + # Verify candidate for partition 1, there's only one active bucket here + # so zero gain is expected. + split_info.ParseFromString(splits[1]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.sparse_float_binary_split_default_left + + self.assertAllClose(0.0, gains[1]) + + self.assertAllClose([expected_left_weight], left_child.value) + + self.assertAllClose([expected_right_weight], right_child.value) + + self.assertEqual(0, split_node.split.feature_column) + + self.assertAllClose(0.52, split_node.split.threshold) + def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self): with self.test_session() as sess: # Batch is 4, 2 classes diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 47698d45c81478f2b694aaadc603f742c44d5351..1ee7f2395ea2ad71a7d380a1cc8f9a77bd4782b3 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -46,6 +46,7 @@ from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import device_setter @@ -61,6 +62,13 @@ USED_HANDLERS_MASK = "used_handlers_mask" LEAF_INDEX = "leaf_index" _FEATURE_NAME_TEMPLATE = "%s_%d" +# Keys in Training state. +GBDTTrainingState = collections.namedtuple("GBDTTrainingState", [ + "num_layer_examples", "num_layer_steps", "num_layers", "active_tree", + "active_layer", "continue_centering", "bias_stats_accumulator", + "steps_accumulator", "handlers" +]) + def _get_column_by_index(tensor, indices): """Returns columns from a 2-D tensor by index.""" @@ -276,6 +284,7 @@ class GradientBoostedDecisionTreeModel(object): learner_config, features, logits_dimension, + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS, feature_columns=None, use_core_columns=False, output_leaf_index=False): @@ -292,7 +301,10 @@ class GradientBoostedDecisionTreeModel(object): learner_config: A learner config. features: `dict` of `Tensor` objects. logits_dimension: An int, the dimension of logits. + loss_reduction: Either `SUM_OVER_NONZERO_WEIGHTS` (mean) or `SUM`. feature_columns: A list of feature columns. + use_core_columns: A boolean specifying whether core feature columns are + used. output_leaf_index: A boolean variable indicating whether to output leaf index into predictions dictionary. @@ -315,6 +327,13 @@ class GradientBoostedDecisionTreeModel(object): self._center_bias = center_bias self._examples_per_layer = examples_per_layer + # Check loss reduction value. + if (loss_reduction != losses.Reduction.SUM and + loss_reduction != losses.Reduction.SUM_OVER_NONZERO_WEIGHTS): + raise ValueError( + "Invalid loss reduction is provided: %s." % loss_reduction) + self._loss_reduction = loss_reduction + # Fill in the defaults. if (learner_config.multi_class_strategy == learner_pb2.LearnerConfig.MULTI_CLASS_STRATEGY_UNSPECIFIED): @@ -325,6 +344,19 @@ class GradientBoostedDecisionTreeModel(object): learner_config.multi_class_strategy = ( learner_pb2.LearnerConfig.DIAGONAL_HESSIAN) + if logits_dimension == 1 or learner_config.multi_class_strategy == ( + learner_pb2.LearnerConfig.TREE_PER_CLASS): + self._gradient_shape = tensor_shape.scalar() + self._hessian_shape = tensor_shape.scalar() + else: + self._gradient_shape = tensor_shape.TensorShape([logits_dimension]) + if (learner_config.multi_class_strategy == + learner_pb2.LearnerConfig.FULL_HESSIAN): + self._hessian_shape = tensor_shape.TensorShape( + ([logits_dimension, logits_dimension])) + else: + # Diagonal hessian strategy. + self._hessian_shape = tensor_shape.TensorShape(([logits_dimension])) if (learner_config.growing_mode == learner_pb2.LearnerConfig.GROWING_MODE_UNSPECIFIED): learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER @@ -359,6 +391,7 @@ class GradientBoostedDecisionTreeModel(object): sparse_int_values, sparse_int_shapes) = extract_features( features, self._feature_columns, use_core_columns) logging.info("Active Feature Columns: " + str(fc_names)) + logging.info("Learner config: " + str(learner_config)) self._fc_names = fc_names self._dense_floats = dense_floats self._sparse_float_indices = sparse_float_indices @@ -522,17 +555,30 @@ class GradientBoostedDecisionTreeModel(object): return self._predict_and_return_dict(self._ensemble_handle, ensemble_stamp, mode) - def train(self, loss, predictions_dict, labels): - """Grows a new tree and adds it to the ensemble. + def _get_class_id(self, predictions_dict): + # Handle different multiclass strategies. + if (self._learner_config.multi_class_strategy == + learner_pb2.LearnerConfig.TREE_PER_CLASS and + self._logits_dimension != 1): + # Choose the class for which the tree is built (one vs rest). + return math_ops.to_int32( + predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension) + return constant_op.constant(-1, dtype=dtypes.int32) + + def update_stats(self, loss, predictions_dict): + """Update the accumulators with stats from this batch. Args: loss: A scalar tensor representing average loss of examples. predictions_dict: Dictionary of Rank 2 `Tensor` representing information about predictions per example. - labels: Rank 2 `Tensor` representing labels per example. Returns: - An op that adds a new tree to the ensemble. + Three values: + - An op that adds a new tree to the ensemble, and + - An op that increments the stamp but removes all the trees and resets + the handlers. This can be used to reset the state of the ensemble. + - A dict containing the training state. Raises: ValueError: if inputs are not valid. @@ -556,13 +602,10 @@ class GradientBoostedDecisionTreeModel(object): aggregation_method=None)[0] strategy = self._learner_config.multi_class_strategy - class_id = constant_op.constant(-1, dtype=dtypes.int32) + class_id = self._get_class_id(predictions_dict) # Handle different multiclass strategies. if strategy == learner_pb2.LearnerConfig.TREE_PER_CLASS: # We build one vs rest trees. - gradient_shape = tensor_shape.scalar() - hessian_shape = tensor_shape.scalar() - if self._logits_dimension == 1: # We have only 1 score, gradients is of shape [batch, 1]. hessians = gradients_impl.gradients( @@ -579,11 +622,6 @@ class GradientBoostedDecisionTreeModel(object): hessian_list = self._diagonal_hessian(gradients, predictions) # Assemble hessian list into a tensor. hessians = array_ops.stack(hessian_list, axis=1) - - # Choose the class for which the tree is built (one vs rest). - class_id = math_ops.to_int32( - predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension) - # Use class id tensor to get the column with that index from gradients # and hessians. squeezed_gradients = array_ops.squeeze( @@ -592,15 +630,10 @@ class GradientBoostedDecisionTreeModel(object): _get_column_by_index(hessians, class_id)) else: # Other multiclass strategies. - gradient_shape = tensor_shape.TensorShape([self._logits_dimension]) - if strategy == learner_pb2.LearnerConfig.FULL_HESSIAN: - hessian_shape = tensor_shape.TensorShape( - ([self._logits_dimension, self._logits_dimension])) hessian_list = self._full_hessian(gradients, predictions) else: # Diagonal hessian strategy. - hessian_shape = tensor_shape.TensorShape(([self._logits_dimension])) hessian_list = self._diagonal_hessian(gradients, predictions) squeezed_gradients = gradients @@ -608,7 +641,7 @@ class GradientBoostedDecisionTreeModel(object): squeezed_hessians = hessians # Get the weights for each example for quantiles calculation, - weights = self._get_weights(hessian_shape, squeezed_hessians) + weights = self._get_weights(self._hessian_shape, squeezed_hessians) # Create all handlers ensuring resources are evenly allocated across PS. fc_name_idx = 0 @@ -622,6 +655,8 @@ class GradientBoostedDecisionTreeModel(object): self._learner_config.regularization.tree_complexity, dtypes.float32) min_node_weight = constant_op.constant( self._learner_config.constraints.min_node_weight, dtypes.float32) + loss_uses_sum_reduction = self._loss_reduction == losses.Reduction.SUM + loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction) epsilon = 0.01 num_quantiles = 100 strategy_tensor = constant_op.constant(strategy) @@ -635,15 +670,18 @@ class GradientBoostedDecisionTreeModel(object): l2_regularization=l2_regularization, tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, - feature_column_group_id=dense_float_column_idx, + feature_column_group_id=constant_op.constant( + dense_float_column_idx), epsilon=epsilon, num_quantiles=num_quantiles, dense_float_column=self._dense_floats[dense_float_column_idx], name=fc_name, - gradient_shape=gradient_shape, - hessian_shape=hessian_shape, + gradient_shape=self._gradient_shape, + hessian_shape=self._hessian_shape, multiclass_strategy=strategy_tensor, - init_stamp_token=init_stamp_token)) + init_stamp_token=init_stamp_token, + loss_uses_sum_reduction=loss_uses_sum_reduction, + )) fc_name_idx += 1 # Create handlers for sparse float columns. @@ -655,7 +693,8 @@ class GradientBoostedDecisionTreeModel(object): l2_regularization=l2_regularization, tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, - feature_column_group_id=sparse_float_column_idx, + feature_column_group_id=constant_op.constant( + sparse_float_column_idx), epsilon=epsilon, num_quantiles=num_quantiles, sparse_float_column=sparse_tensor.SparseTensor( @@ -663,10 +702,11 @@ class GradientBoostedDecisionTreeModel(object): self._sparse_float_values[sparse_float_column_idx], self._sparse_float_shapes[sparse_float_column_idx]), name=fc_name, - gradient_shape=gradient_shape, - hessian_shape=hessian_shape, + gradient_shape=self._gradient_shape, + hessian_shape=self._hessian_shape, multiclass_strategy=strategy_tensor, - init_stamp_token=init_stamp_token)) + init_stamp_token=init_stamp_token, + loss_uses_sum_reduction=loss_uses_sum_reduction)) fc_name_idx += 1 # Create handlers for sparse int columns. @@ -678,32 +718,20 @@ class GradientBoostedDecisionTreeModel(object): l2_regularization=l2_regularization, tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, - feature_column_group_id=sparse_int_column_idx, + feature_column_group_id=constant_op.constant( + sparse_int_column_idx), sparse_int_column=sparse_tensor.SparseTensor( self._sparse_int_indices[sparse_int_column_idx], self._sparse_int_values[sparse_int_column_idx], self._sparse_int_shapes[sparse_int_column_idx]), name=fc_name, - gradient_shape=gradient_shape, - hessian_shape=hessian_shape, + gradient_shape=self._gradient_shape, + hessian_shape=self._hessian_shape, multiclass_strategy=strategy_tensor, - init_stamp_token=init_stamp_token)) + init_stamp_token=init_stamp_token, + loss_uses_sum_reduction=loss_uses_sum_reduction)) fc_name_idx += 1 - # Create steps accumulator. - steps_accumulator = stats_accumulator_ops.StatsAccumulator( - stamp_token=0, - gradient_shape=tensor_shape.scalar(), - hessian_shape=tensor_shape.scalar(), - name="StepsAccumulator") - - # Create bias stats accumulator. - bias_stats_accumulator = stats_accumulator_ops.StatsAccumulator( - stamp_token=0, - gradient_shape=gradient_shape, - hessian_shape=hessian_shape, - name="BiasAccumulator") - # Create ensemble stats variables. num_layer_examples = variables.Variable( initial_value=array_ops.zeros([], dtypes.int64), @@ -725,7 +753,23 @@ class GradientBoostedDecisionTreeModel(object): initial_value=array_ops.zeros([], dtypes.int64), name="active_layer", trainable=False) - + # Variable that becomes false once bias centering is done. + continue_centering = variables.Variable( + initial_value=self._center_bias, + name="continue_centering", + trainable=False) + # Create bias stats accumulator. + bias_stats_accumulator = stats_accumulator_ops.StatsAccumulator( + stamp_token=0, + gradient_shape=self._gradient_shape, + hessian_shape=self._hessian_shape, + name="BiasAccumulator") + # Create steps accumulator. + steps_accumulator = stats_accumulator_ops.StatsAccumulator( + stamp_token=0, + gradient_shape=tensor_shape.scalar(), + hessian_shape=tensor_shape.scalar(), + name="StepsAccumulator") # Create ensemble stats summaries. summary.scalar("layer_stats/num_examples", num_layer_examples) summary.scalar("layer_stats/num_steps", num_layer_steps) @@ -734,16 +778,13 @@ class GradientBoostedDecisionTreeModel(object): # Update bias stats. stats_update_ops = [] - continue_centering = variables.Variable( - initial_value=self._center_bias, - name="continue_centering", - trainable=False) + stats_update_ops.append( control_flow_ops.cond( continue_centering, - self._make_update_bias_stats_fn(ensemble_stamp, predictions, - gradients, bias_stats_accumulator), - control_flow_ops.no_op)) + self._make_update_bias_stats_fn( + ensemble_stamp, predictions, gradients, + bias_stats_accumulator), control_flow_ops.no_op)) # Update handler stats. handler_reads = collections.OrderedDict() @@ -800,8 +841,8 @@ class GradientBoostedDecisionTreeModel(object): lambda: active_handlers)) # Prepare empty gradients and hessians when handlers are not ready. - empty_hess_shape = [1] + hessian_shape.as_list() - empty_grad_shape = [1] + gradient_shape.as_list() + empty_hess_shape = [1] + self._hessian_shape.as_list() + empty_grad_shape = [1] + self._gradient_shape.as_list() empty_gradients = constant_op.constant( [], dtype=dtypes.float32, shape=empty_grad_shape) @@ -823,34 +864,86 @@ class GradientBoostedDecisionTreeModel(object): per_handler_updates, ensemble_stamp, worker_device) for update in update_results.values(): stats_update_ops += update + + training_state = GBDTTrainingState( + num_layer_examples=num_layer_examples, + num_layer_steps=num_layer_steps, + num_layers=num_layers, + active_tree=active_tree, + active_layer=active_layer, + continue_centering=continue_centering, + bias_stats_accumulator=bias_stats_accumulator, + steps_accumulator=steps_accumulator, + handlers=handlers) + + reset_op = control_flow_ops.no_op() + if self._is_chief: + # Advance the ensemble stamp to throw away staggered workers. + stamp_token, _ = model_ops.tree_ensemble_serialize(self._ensemble_handle) + next_stamp_token = stamp_token + 1 + + reset_ops = [] + for handler in handlers: + reset_ops.append(handler.make_splits(stamp_token, next_stamp_token, 0)) + if self._center_bias: + reset_ops.append( + bias_stats_accumulator.flush(stamp_token, next_stamp_token)) + reset_ops.append(steps_accumulator.flush(stamp_token, next_stamp_token)) + reset_ops.append(self._finalized_trees.assign(0).op) + reset_ops.append(self._attempted_trees.assign(0).op) + reset_ops.append( + model_ops.tree_ensemble_deserialize( + self._ensemble_handle, + stamp_token=next_stamp_token, + tree_ensemble_config="", + name="reset_gbdt")) + + reset_op = control_flow_ops.group([reset_ops]) + + return stats_update_ops, reset_op, training_state + + def increment_step_counter_and_maybe_update_ensemble(self, predictions_dict, + training_state): + """Increments number of visited examples and grows the ensemble. + + If the number of visited examples reaches the target examples_per_layer, + ensemble is updated. + + Args: + predictions_dict: Dictionary of Rank 2 `Tensor` representing information + about predictions per example. + training_state: `dict` returned by update_stats. + + Returns: + An op that updates the counters and potientially grows the ensemble. + """ + batch_size = math_ops.cast( + array_ops.shape(predictions_dict[PREDICTIONS])[0], dtypes.float32) + ensemble_stamp = predictions_dict[ENSEMBLE_STAMP] # Accumulate a step after updating stats. - batch_size = math_ops.cast(array_ops.shape(labels)[0], dtypes.float32) - with ops.control_dependencies(stats_update_ops): - add_step_op = steps_accumulator.add(ensemble_stamp, [0], [[0, 0]], - [batch_size], [1.0]) - # Determine learning rate. - learning_rate_tuner = self._learner_config.learning_rate_tuner.WhichOneof( - "tuner") - if learning_rate_tuner == "fixed" or learning_rate_tuner == "dropout": - tuner = getattr(self._learner_config.learning_rate_tuner, - learning_rate_tuner) - learning_rate = tuner.learning_rate - else: - # TODO(nponomareva, soroush) do the line search. - raise ValueError("Line search learning rate is not yet supported.") + steps_accumulator = training_state.steps_accumulator + num_layer_examples = training_state.num_layer_examples + num_layer_steps = training_state.num_layer_steps + active_layer = training_state.active_layer + add_step_op = steps_accumulator.add( + ensemble_stamp, [0], [[0, 0]], [batch_size], [1.0]) # After adding the step, decide if further processing is needed. ensemble_update_ops = [add_step_op] + class_id = self._get_class_id(predictions_dict) + with ops.control_dependencies([add_step_op]): if self._is_chief: dropout_seed = predictions_dict[NUM_TREES_ATTEMPTED] # Get accumulated steps and examples for the current layer. - _, _, _, _, acc_examples, acc_steps = steps_accumulator.serialize() + _, _, _, _, acc_examples, acc_steps = ( + steps_accumulator.serialize()) acc_examples = math_ops.cast(acc_examples[0], dtypes.int64) acc_steps = math_ops.cast(acc_steps[0], dtypes.int64) - ensemble_update_ops.append(num_layer_examples.assign(acc_examples)) + ensemble_update_ops.append( + num_layer_examples.assign(acc_examples)) ensemble_update_ops.append(num_layer_steps.assign(acc_steps)) # Determine whether we need to update tree ensemble. examples_per_layer = self._examples_per_layer @@ -859,18 +952,172 @@ class GradientBoostedDecisionTreeModel(object): ensemble_update_ops.append( control_flow_ops.cond( acc_examples >= examples_per_layer, - self._make_update_ensemble_fn( - ensemble_stamp, steps_accumulator, bias_stats_accumulator, - continue_centering, learning_rate, handlers, num_layers, - active_tree, active_layer, dropout_seed, class_id), + self.make_update_ensemble_fn(ensemble_stamp, training_state, + dropout_seed, class_id), control_flow_ops.no_op)) - # Calculate the loss to be reported. # Note, the loss is calculated from the prediction considering dropouts, so # that the value might look staggering over steps when the dropout ratio is # high. eval_loss might be referred instead in the aspect of convergence. return control_flow_ops.group(*ensemble_update_ops) + def make_update_ensemble_fn(self, ensemble_stamp, training_state, + dropout_seed, class_id): + """A method to create the function which updates the tree ensemble.""" + # Determine learning rate. + learning_rate_tuner = self._learner_config.learning_rate_tuner.WhichOneof( + "tuner") + if learning_rate_tuner == "fixed" or learning_rate_tuner == "dropout": + tuner = getattr(self._learner_config.learning_rate_tuner, + learning_rate_tuner) + learning_rate = tuner.learning_rate + else: + # TODO(nponomareva, soroush) do the line search. + raise ValueError("Line search learning rate is not yet supported.") + + def _update_ensemble(): + """A method to update the tree ensemble.""" + # Get next stamp token. + next_ensemble_stamp = ensemble_stamp + 1 + # Finalize bias stats. + _, _, _, bias_grads, bias_hess = ( + training_state.bias_stats_accumulator.flush(ensemble_stamp, + next_ensemble_stamp)) + + # Finalize handler splits. + are_splits_ready_list = [] + partition_ids_list = [] + gains_list = [] + split_info_list = [] + + for handler in training_state.handlers: + (are_splits_ready, + partition_ids, gains, split_info) = handler.make_splits( + ensemble_stamp, next_ensemble_stamp, class_id) + are_splits_ready_list.append(are_splits_ready) + partition_ids_list.append(partition_ids) + gains_list.append(gains) + split_info_list.append(split_info) + # Stack all the inputs to one tensor per type. + # This is a workaround for the slowness of graph building in tf.cond. + # See (b/36554864). + split_sizes = array_ops.reshape( + array_ops.shape_n(partition_ids_list), [len(partition_ids_list)]) + partition_ids = array_ops.concat(partition_ids_list, axis=0) + gains = array_ops.concat(gains_list, axis=0) + split_infos = array_ops.concat(split_info_list, axis=0) + + # Determine if all splits are ready. + are_all_splits_ready = math_ops.reduce_all( + array_ops.stack( + are_splits_ready_list, axis=0, name="stack_handler_readiness")) + + # Define bias centering update operation. + def _center_bias_fn(): + # Center tree ensemble bias. + delta_updates = array_ops.where(bias_hess > 0, -bias_grads / bias_hess, + array_ops.zeros_like(bias_grads)) + center_bias = training_ops.center_tree_ensemble_bias( + tree_ensemble_handle=self._ensemble_handle, + stamp_token=ensemble_stamp, + next_stamp_token=next_ensemble_stamp, + delta_updates=delta_updates, + learner_config=self._learner_config_serialized) + return training_state.continue_centering.assign(center_bias) + + # Define ensemble growing operations. + def _grow_ensemble_ready_fn(): + # Grow the ensemble given the current candidates. + sizes = array_ops.unstack(split_sizes) + partition_ids_list = list(array_ops.split(partition_ids, sizes, axis=0)) + gains_list = list(array_ops.split(gains, sizes, axis=0)) + split_info_list = list(array_ops.split(split_infos, sizes, axis=0)) + return training_ops.grow_tree_ensemble( + tree_ensemble_handle=self._ensemble_handle, + stamp_token=ensemble_stamp, + next_stamp_token=next_ensemble_stamp, + learning_rate=learning_rate, + partition_ids=partition_ids_list, + gains=gains_list, + splits=split_info_list, + learner_config=self._learner_config_serialized, + dropout_seed=dropout_seed, + center_bias=self._center_bias) + + def _grow_ensemble_not_ready_fn(): + # Don't grow the ensemble, just update the stamp. + return training_ops.grow_tree_ensemble( + tree_ensemble_handle=self._ensemble_handle, + stamp_token=ensemble_stamp, + next_stamp_token=next_ensemble_stamp, + learning_rate=0, + partition_ids=[], + gains=[], + splits=[], + learner_config=self._learner_config_serialized, + dropout_seed=dropout_seed, + center_bias=self._center_bias) + + def _grow_ensemble_fn(): + # Conditionally grow an ensemble depending on whether the splits + # from all the handlers are ready. + return control_flow_ops.cond(are_all_splits_ready, + _grow_ensemble_ready_fn, + _grow_ensemble_not_ready_fn) + + # Update ensemble. + update_ops = [are_all_splits_ready] + if self._center_bias: + update_model = control_flow_ops.cond(training_state.continue_centering, + _center_bias_fn, _grow_ensemble_fn) + else: + update_model = _grow_ensemble_fn() + update_ops.append(update_model) + + # Update ensemble stats. + with ops.control_dependencies([update_model]): + stats = training_ops.tree_ensemble_stats( + self._ensemble_handle, stamp_token=next_ensemble_stamp) + update_ops.append(self._finalized_trees.assign(stats.num_trees)) + update_ops.append(self._attempted_trees.assign(stats.attempted_trees)) + update_ops.append(training_state.num_layers.assign(stats.num_layers)) + update_ops.append(training_state.active_tree.assign(stats.active_tree)) + update_ops.append( + training_state.active_layer.assign(stats.active_layer)) + + # Flush step stats. + update_ops.extend( + training_state.steps_accumulator.flush(ensemble_stamp, + next_ensemble_stamp)) + return control_flow_ops.group(*update_ops, name="update_ensemble") + + return _update_ensemble + + def get_number_of_trees_tensor(self): + return self._finalized_trees, self._attempted_trees + + def train(self, loss, predictions_dict, labels): + """Updates the accumalator stats and grows the ensemble. + + Args: + loss: A scalar tensor representing average loss of examples. + predictions_dict: Dictionary of Rank 2 `Tensor` representing information + about predictions per example. + labels: Rank 2 `Tensor` representing labels per example. Has no effect + on the training and is only kept for backward compatibility. + + Returns: + An op that adds a new tree to the ensemble. + + Raises: + ValueError: if inputs are not valid. + """ + del labels # unused; kept for backward compatibility. + update_op, _, training_state = self.update_stats(loss, predictions_dict) + with ops.control_dependencies(update_op): + return self.increment_step_counter_and_maybe_update_ensemble( + predictions_dict, training_state) + def _get_weights(self, hessian_shape, hessians): """Derives weights to be used based on hessians and multiclass strategy.""" if hessian_shape == tensor_shape.scalar(): @@ -986,127 +1233,3 @@ class GradientBoostedDecisionTreeModel(object): return control_flow_ops.group(*[add_stats_op], name="update_bias_stats") return _update_bias_stats - - def _make_update_ensemble_fn(self, ensemble_stamp, steps_accumulator, - bias_stats_accumulator, continue_centering, - learning_rate, handlers, num_layers, active_tree, - active_layer, dropout_seed, class_id): - """A method to create the function which updates the tree ensemble.""" - - def _update_ensemble(): - """A method to update the tree ensemble.""" - # Get next stamp token. - next_ensemble_stamp = ensemble_stamp + 1 - # Finalize bias stats. - _, _, _, bias_grads, bias_hess = bias_stats_accumulator.flush( - ensemble_stamp, next_ensemble_stamp) - - # Finalize handler splits. - are_splits_ready_list = [] - partition_ids_list = [] - gains_list = [] - split_info_list = [] - - for handler in handlers: - (are_splits_ready, - partition_ids, gains, split_info) = handler.make_splits( - ensemble_stamp, next_ensemble_stamp, class_id) - are_splits_ready_list.append(are_splits_ready) - partition_ids_list.append(partition_ids) - gains_list.append(gains) - split_info_list.append(split_info) - # Stack all the inputs to one tensor per type. - # This is a workaround for the slowness of graph building in tf.cond. - # See (b/36554864). - split_sizes = array_ops.reshape( - array_ops.shape_n(partition_ids_list), [len(partition_ids_list)]) - partition_ids = array_ops.concat(partition_ids_list, axis=0) - gains = array_ops.concat(gains_list, axis=0) - split_infos = array_ops.concat(split_info_list, axis=0) - - # Determine if all splits are ready. - are_all_splits_ready = math_ops.reduce_all( - array_ops.stack( - are_splits_ready_list, axis=0, name="stack_handler_readiness")) - - # Define bias centering update operation. - def _center_bias_fn(): - # Center tree ensemble bias. - delta_updates = array_ops.where(bias_hess > 0, -bias_grads / bias_hess, - array_ops.zeros_like(bias_grads)) - center_bias = training_ops.center_tree_ensemble_bias( - tree_ensemble_handle=self._ensemble_handle, - stamp_token=ensemble_stamp, - next_stamp_token=next_ensemble_stamp, - delta_updates=delta_updates, - learner_config=self._learner_config_serialized) - return continue_centering.assign(center_bias) - - # Define ensemble growing operations. - def _grow_ensemble_ready_fn(): - # Grow the ensemble given the current candidates. - sizes = array_ops.unstack(split_sizes) - partition_ids_list = list(array_ops.split(partition_ids, sizes, axis=0)) - gains_list = list(array_ops.split(gains, sizes, axis=0)) - split_info_list = list(array_ops.split(split_infos, sizes, axis=0)) - return training_ops.grow_tree_ensemble( - tree_ensemble_handle=self._ensemble_handle, - stamp_token=ensemble_stamp, - next_stamp_token=next_ensemble_stamp, - learning_rate=learning_rate, - partition_ids=partition_ids_list, - gains=gains_list, - splits=split_info_list, - learner_config=self._learner_config_serialized, - dropout_seed=dropout_seed, - center_bias=self._center_bias) - - def _grow_ensemble_not_ready_fn(): - # Don't grow the ensemble, just update the stamp. - return training_ops.grow_tree_ensemble( - tree_ensemble_handle=self._ensemble_handle, - stamp_token=ensemble_stamp, - next_stamp_token=next_ensemble_stamp, - learning_rate=0, - partition_ids=[], - gains=[], - splits=[], - learner_config=self._learner_config_serialized, - dropout_seed=dropout_seed, - center_bias=self._center_bias) - - def _grow_ensemble_fn(): - # Conditionally grow an ensemble depending on whether the splits - # from all the handlers are ready. - return control_flow_ops.cond(are_all_splits_ready, - _grow_ensemble_ready_fn, - _grow_ensemble_not_ready_fn) - - # Update ensemble. - update_ops = [are_all_splits_ready] - if self._center_bias: - update_model = control_flow_ops.cond(continue_centering, - _center_bias_fn, _grow_ensemble_fn) - else: - update_model = _grow_ensemble_fn() - update_ops.append(update_model) - - # Update ensemble stats. - with ops.control_dependencies([update_model]): - stats = training_ops.tree_ensemble_stats( - self._ensemble_handle, stamp_token=next_ensemble_stamp) - update_ops.append(self._finalized_trees.assign(stats.num_trees)) - update_ops.append(self._attempted_trees.assign(stats.attempted_trees)) - update_ops.append(num_layers.assign(stats.num_layers)) - update_ops.append(active_tree.assign(stats.active_tree)) - update_ops.append(active_layer.assign(stats.active_layer)) - - # Flush step stats. - update_ops.extend( - steps_accumulator.flush(ensemble_stamp, next_ensemble_stamp)) - return control_flow_ops.group(*update_ops, name="update_ensemble") - - return _update_ensemble - - def get_number_of_trees_tensor(self): - return self._finalized_trees, self._attempted_trees diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index e3d4397fadcbaf148f7f6cfaca13e850639786cf..f7867d882d6813a8701065ad0ce8d27f8bb9c301 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -29,6 +29,7 @@ from tensorflow.contrib.layers.python.layers import feature_column as feature_co from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -1560,6 +1561,301 @@ class GbdtTest(test_util.TensorFlowTestCase): self.assertEquals(output.growing_metadata.num_layers_attempted, 2) + def testResetModelBeforeAndAfterSplit(self): + """Tests whether resetting works.""" + with self.test_session(): + # First build a small tree and train it to verify training works. + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, tree_ensemble_config="", name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + features = {} + features["dense_float"] = array_ops.ones([4, 1], dtypes.float32) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = model_ops.tree_ensemble_stamp_token(ensemble_handle) + + predictions_dict = { + "predictions": predictions, + "predictions_no_dropout": predictions, + "partition_ids": partition_ids, + "ensemble_stamp": ensemble_stamp, + "num_trees": 12, + "max_tree_depth": 4, + } + + labels = array_ops.ones([4, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + loss = math_ops.reduce_mean(_squared_loss(labels, weights, predictions)) + + # Create train op. + update_op, reset_op, training_state = gbdt_model.update_stats( + loss, predictions_dict) + with ops.control_dependencies(update_op): + train_op = gbdt_model.increment_step_counter_and_maybe_update_ensemble( + predictions_dict, training_state) + + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + original_stamp = ensemble_stamp.eval() + expected_tree = """ + nodes { + dense_float_binary_split { + threshold: 1.0 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 0 + } + } + nodes { + leaf { + vector { + value: 0.25 + } + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + }""" + + def _train_once_and_check(expect_split): + stamp = ensemble_stamp.eval() + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(stamp_token.eval(), stamp + 1) + if expect_split: + # State of the ensemble after a split occurs. + self.assertEquals(len(output.trees), 1) + self.assertProtoEquals(expected_tree, output.trees[0]) + else: + # State of the ensemble after a single accumulation but before any + # splitting occurs + self.assertEquals(len(output.trees), 0) + self.assertProtoEquals(""" + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 1 + }""", output) + + def _run_reset(): + stamp_before_reset = ensemble_stamp.eval() + reset_op.run() + stamp_after_reset = ensemble_stamp.eval() + self.assertNotEquals(stamp_after_reset, stamp_before_reset) + + _, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertProtoEquals("", output) + + return stamp_after_reset + + # Exit after one train_op, so no new layer are created but the handlers + # contain enough information to split on the next call to train. + _train_once_and_check(expect_split=False) + self.assertEquals(ensemble_stamp.eval(), original_stamp + 1) + + # Reset the handlers so it still requires two training calls to split. + stamp_after_reset = _run_reset() + + _train_once_and_check(expect_split=False) + _train_once_and_check(expect_split=True) + self.assertEquals(ensemble_stamp.eval(), stamp_after_reset + 2) + + # This time, test that the reset_op works right after splitting. + stamp_after_reset = _run_reset() + + # Test that after resetting, the tree can be trained as normal. + _train_once_and_check(expect_split=False) + _train_once_and_check(expect_split=True) + self.assertEquals(ensemble_stamp.eval(), stamp_after_reset + 2) + + def testResetModelNonChief(self): + """Tests the reset function on a non-chief worker.""" + with self.test_session(): + # Create ensemble with one bias node. + ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + text_format.Merge( + """ + trees { + nodes { + leaf { + vector { + value: 0.25 + } + } + } + } + tree_weights: 1.0 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 1 + is_finalized: false + }""", ensemble_config) + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config=ensemble_config.SerializeToString(), + name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + features = {} + features["dense_float"] = array_ops.ones([4, 1], dtypes.float32) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=False, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = model_ops.tree_ensemble_stamp_token(ensemble_handle) + + predictions_dict = { + "predictions": predictions, + "predictions_no_dropout": predictions, + "partition_ids": partition_ids, + "ensemble_stamp": ensemble_stamp + } + + labels = array_ops.ones([4, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + loss = math_ops.reduce_mean(_squared_loss(labels, weights, predictions)) + + # Create reset op. + _, reset_op, _ = gbdt_model.update_stats( + loss, predictions_dict) + + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + # Reset op doesn't do anything because this is a non-chief worker. + reset_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 1) + self.assertEquals(len(output.tree_weights), 1) + self.assertEquals(stamp_token.eval(), 0) + + def testResetModelWithCenterBias(self): + """Tests the reset function running on chief with bias centering.""" + with self.test_session(): + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, tree_ensemble_config="", name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 1 + learner_config.constraints.min_node_weight = 0 + features = {} + features["dense_float"] = array_ops.ones([4, 1], dtypes.float32) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=True, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = model_ops.tree_ensemble_stamp_token(ensemble_handle) + + predictions_dict = { + "predictions": predictions, + "predictions_no_dropout": predictions, + "partition_ids": partition_ids, + "ensemble_stamp": ensemble_stamp, + "num_trees": 12, + } + + labels = array_ops.ones([4, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + loss = math_ops.reduce_mean(_squared_loss(labels, weights, predictions)) + + # Create train op. + update_op, reset_op, training_state = gbdt_model.update_stats( + loss, predictions_dict) + with ops.control_dependencies(update_op): + train_op = gbdt_model.increment_step_counter_and_maybe_update_ensemble( + predictions_dict, training_state) + + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + # On first run, expect bias to be centered. + def train_and_check(): + train_op.run() + _, serialized = model_ops.tree_ensemble_serialize(ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + expected_tree = """ + nodes { + leaf { + vector { + value: 0.25 + } + } + }""" + self.assertEquals(len(output.trees), 1) + self.assertAllEqual(output.tree_weights, [1.0]) + self.assertProtoEquals(expected_tree, output.trees[0]) + + train_and_check() + self.assertEquals(ensemble_stamp.eval(), 1) + + reset_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 0) + self.assertEquals(len(output.tree_weights), 0) + self.assertEquals(stamp_token.eval(), 2) + + train_and_check() + self.assertEquals(ensemble_stamp.eval(), 3) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index 9aa4614967958247dde5d81b862baaafd8d4144a..8c1ce5c2a2d552e30d3b676e3ac8b5fc7c74a917 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -22,6 +22,7 @@ Visualization and inspection: Managing dependencies: @@capture_dependencies @@Checkpointable +@@CheckpointableBase @@CheckpointableObjectGraph @@NoDependency @@split_dependency @@ -40,10 +41,11 @@ from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph -from tensorflow.python.training.checkpointable.base import Checkpointable -from tensorflow.python.training.checkpointable.base import NoDependency +from tensorflow.python.training.checkpointable.base import CheckpointableBase from tensorflow.python.training.checkpointable.data_structures import List from tensorflow.python.training.checkpointable.data_structures import Mapping +from tensorflow.python.training.checkpointable.tracking import Checkpointable +from tensorflow.python.training.checkpointable.tracking import NoDependency from tensorflow.python.training.checkpointable.util import capture_dependencies from tensorflow.python.training.checkpointable.util import list_objects from tensorflow.python.training.checkpointable.util import object_metadata @@ -51,4 +53,3 @@ from tensorflow.python.training.checkpointable.util import object_metadata from tensorflow.python.util.all_util import remove_undocumented remove_undocumented(module_name=__name__) - diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py index 3717d7f583ffdc205a279d45df60cddbc5cbf08e..64d056bd689a14c0c58d7a0f75c833c71b00a5c3 100644 --- a/tensorflow/contrib/checkpoint/python/containers_test.py +++ b/tensorflow/contrib/checkpoint/python/containers_test.py @@ -26,13 +26,13 @@ from tensorflow.python.keras import layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test -from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.checkpointable import util class UniqueNameTrackerTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNames(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") @@ -48,11 +48,11 @@ class UniqueNameTrackerTests(test.TestCase): slots.track(y, "y") self.evaluate((x1.initializer, x2.initializer, x3.initializer, y.initializer)) - save_root = checkpointable_utils.Checkpoint(slots=slots) + save_root = util.Checkpoint(slots=slots) save_path = save_root.save(checkpoint_prefix) - restore_slots = checkpointable.Checkpointable() - restore_root = checkpointable_utils.Checkpoint( + restore_slots = tracking.Checkpointable() + restore_root = util.Checkpoint( slots=restore_slots) status = restore_root.restore(save_path) restore_slots.x = resource_variable_ops.ResourceVariable(0.) @@ -65,9 +65,9 @@ class UniqueNameTrackerTests(test.TestCase): self.assertEqual(4., self.evaluate(restore_slots.x_1_1)) self.assertEqual(5., self.evaluate(restore_slots.y)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testExample(self): - class SlotManager(checkpointable.Checkpointable): + class SlotManager(tracking.Checkpointable): def __init__(self): self.slotdeps = containers.UniqueNameTracker() @@ -83,11 +83,11 @@ class UniqueNameTrackerTests(test.TestCase): manager = SlotManager() self.evaluate([v.initializer for v in manager.slots]) - checkpoint = checkpointable_utils.Checkpoint(slot_manager=manager) + checkpoint = util.Checkpoint(slot_manager=manager) checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") save_path = checkpoint.save(checkpoint_prefix) - metadata = checkpointable_utils.object_metadata(save_path) + metadata = util.object_metadata(save_path) dependency_names = [] for node in metadata.nodes: for child in node.children: @@ -97,7 +97,7 @@ class UniqueNameTrackerTests(test.TestCase): dependency_names, ["x", "x_1", "y", "slot_manager", "slotdeps", "save_counter"]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLayers(self): tracker = containers.UniqueNameTracker() tracker.track(layers.Dense(3), "dense") diff --git a/tensorflow/contrib/checkpoint/python/split_dependency_test.py b/tensorflow/contrib/checkpoint/python/split_dependency_test.py index 69dc0b9be2d5548852c37552a64a0d31c9557b43..00a805af25d5d0ea723db5d015fb12bf45c53857 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency_test.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency_test.py @@ -23,8 +23,9 @@ from tensorflow.python.eager import test from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.checkpointable import util def _split_variable_closure(variable): @@ -43,7 +44,7 @@ def _combine_variable_closure(variable): return _consume_restore_buffer_fn -class SaveTensorSlicesAsDeps(checkpointable.CheckpointableBase): +class SaveTensorSlicesAsDeps(base.CheckpointableBase): def __init__(self): self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.]) @@ -58,14 +59,14 @@ class SaveTensorSlicesAsDeps(checkpointable.CheckpointableBase): self._track_checkpointable(dep, name=name) -class HasRegularDeps(checkpointable.Checkpointable): +class HasRegularDeps(tracking.Checkpointable): def __init__(self): self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) self.second_half = resource_variable_ops.ResourceVariable([0., 0.]) -class OnlyOneDep(checkpointable.Checkpointable): +class OnlyOneDep(tracking.Checkpointable): def __init__(self): self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) @@ -73,9 +74,9 @@ class OnlyOneDep(checkpointable.Checkpointable): class SplitTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSaveRestoreSplitDep(self): - save_checkpoint = checkpointable_utils.Checkpoint( + save_checkpoint = util.Checkpoint( dep=SaveTensorSlicesAsDeps()) self.evaluate(save_checkpoint.dep.combined.assign([1., 2., 3., 4.])) checkpoint_directory = self.get_temp_dir() @@ -83,7 +84,7 @@ class SplitTests(test.TestCase): save_path = save_checkpoint.save(checkpoint_prefix) regular_deps = HasRegularDeps() - regular_restore_checkpoint = checkpointable_utils.Checkpoint( + regular_restore_checkpoint = util.Checkpoint( dep=regular_deps) regular_restore_checkpoint.restore( save_path).assert_consumed().run_restore_ops() @@ -91,7 +92,7 @@ class SplitTests(test.TestCase): self.assertAllEqual([3., 4.], self.evaluate(regular_deps.second_half)) one_dep = OnlyOneDep() - one_dep_restore_checkpoint = checkpointable_utils.Checkpoint(dep=one_dep) + one_dep_restore_checkpoint = util.Checkpoint(dep=one_dep) status = one_dep_restore_checkpoint.restore(save_path) with self.assertRaises(AssertionError): # Missing the second dependency. @@ -99,7 +100,7 @@ class SplitTests(test.TestCase): status.run_restore_ops() self.assertAllEqual([1., 2.], self.evaluate(one_dep.first_half)) - restore_checkpoint = checkpointable_utils.Checkpoint() + restore_checkpoint = util.Checkpoint() status = restore_checkpoint.restore(save_path) restore_checkpoint.dep = SaveTensorSlicesAsDeps() status.assert_consumed().run_restore_ops() diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index 1a7a3759baa4a5559b4b70ff4f7467c41da9111f..523a9efcf05f5d32589f6e1734f866bf8b4b9cdc 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -50,6 +50,7 @@ py_library( deps = [ ":gen_bigquery_reader_ops", ":gen_gcs_config_ops", + "//tensorflow/contrib/bigtable", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:io_ops", "//tensorflow/python:util", diff --git a/tensorflow/contrib/cloud/README.md b/tensorflow/contrib/cloud/README.md new file mode 100644 index 0000000000000000000000000000000000000000..134ce057f4334096b4fbbec29cc85f0ea42c9f86 --- /dev/null +++ b/tensorflow/contrib/cloud/README.md @@ -0,0 +1,18 @@ +# Cloud # + +## BigTable ## + +[Google Cloud BigTable](https://cloud.google.com/bigtable/) is a high +performance storage system that can store and serve training data. This contrib +package contains an experimental integration with TensorFlow. + +> **Status: Highly experimental.** The current implementation is very much in +> flux. Please use at your own risk! :-) + + + +## Cloud Storage (GCS) ## + +The Google Cloud Storage ops allow the user to configure the GCS File System. + + diff --git a/tensorflow/contrib/cloud/__init__.py b/tensorflow/contrib/cloud/__init__.py index ef7aa7624ce7b9b6480c4d088a2fb7678a7acc76..af81106a6848bfd8c91108b56c8150d47c3eb501 100644 --- a/tensorflow/contrib/cloud/__init__.py +++ b/tensorflow/contrib/cloud/__init__.py @@ -18,15 +18,24 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=line-too-long,wildcard-import +import os + +# pylint: disable=line-too-long,wildcard-import,g-import-not-at-top from tensorflow.contrib.cloud.python.ops.bigquery_reader_ops import * from tensorflow.contrib.cloud.python.ops.gcs_config_ops import * -# pylint: enable=line-too-long,wildcard-import + +if os.name != 'nt': + from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigTable + from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigtableClient + +del os from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'BigQueryReader', + 'BigTable', + 'BigtableClient', 'BlockCacheParams', 'configure_colab_session', 'configure_gcs', diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py index 8c8c5acb31af69b4f738a13c6548cdd31947d71a..95e7e744d34391a511cdba7702aad369b8d9d9c0 100644 --- a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py @@ -120,13 +120,18 @@ class ConfigureGcsHook(training.SessionRunHook): def begin(self): if self._credentials: self._credentials_placeholder = array_ops.placeholder(dtypes.string) - self._credentials_ops = gen_gcs_config_ops.gcs_configure_credentials( + self._credentials_op = gen_gcs_config_ops.gcs_configure_credentials( self._credentials_placeholder) + else: + self._credentials_op = None + if self._block_cache: self._block_cache_op = gen_gcs_config_ops.gcs_configure_block_cache( max_cache_size=self._block_cache.max_bytes, block_size=self._block_cache.block_size, max_staleness=self._block_cache.max_staleness) + else: + self._block_cache_op = None def after_create_session(self, session, coord): del coord diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py index fc0c9948129116ac371c64fc01a96ecc6194e244..9b6c056d6c8adfa50b95aefb8e9740631327a572 100644 --- a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py @@ -29,6 +29,16 @@ class GcsConfigOpsTest(test.TestCase): with self.test_session() as sess: gcs_config_ops.configure_gcs(sess, block_cache=cfg) + def testConfigureGcsHook(self): + creds = {'client_id': 'fake_client', + 'refresh_token': 'fake_token', + 'client_secret': 'fake_secret', + 'type': 'authorized_user'} + hook = gcs_config_ops.ConfigureGcsHook(credentials=creds) + hook.begin() + with self.test_session() as sess: + sess.run = lambda _, feed_dict=None, options=None, run_metadata=None: None + hook.after_create_session(sess, None) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index e524e9e7437b19e0d117fe7b85042e8154773a02..a0a5b0e00c1979ebf8850408785135b9ceac7d2a 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -299,17 +299,20 @@ include_directories( ${double_conversion_INCLUDE_DIR} ) -if(tensorflow_ENABLE_SSL_SUPPORT) - include(boringssl) - list(APPEND tensorflow_EXTERNAL_LIBRARIES ${boringssl_STATIC_LIBRARIES}) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES boringssl) - include_directories(${boringssl_INCLUDE_DIR}) -endif() if(tensorflow_ENABLE_GRPC_SUPPORT) + if(tensorflow_ENABLE_SSL_SUPPORT) + include(boringssl) + include_directories(${boringssl_INCLUDE_DIR}) + endif() include(grpc) + include_directories(${GRPC_INCLUDE_DIRS}) + # Place boringssl after grpc as grpc depends on boringssl. list(APPEND tensorflow_EXTERNAL_LIBRARIES ${grpc_STATIC_LIBRARIES}) list(APPEND tensorflow_EXTERNAL_DEPENDENCIES grpc) - include_directories(${GRPC_INCLUDE_DIRS}) + if(tensorflow_ENABLE_SSL_SUPPORT) + list(APPEND tensorflow_EXTERNAL_LIBRARIES ${boringssl_STATIC_LIBRARIES}) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES boringssl) + endif() endif() if(tensorflow_ENABLE_JEMALLOC_SUPPORT) include(jemalloc) @@ -336,40 +339,14 @@ endif() # MKL Support if (tensorflow_ENABLE_MKL_SUPPORT) add_definitions(-DINTEL_MKL -DEIGEN_USE_VML) - if (WIN32) - find_path(MKL_HOME_PLATFORM mkl - PATHS ${MKL_HOME} ${MKL_HOME}/../ ${MKL_HOME}/../../ - $ENV{MKLROOT} $ENV{MKLROOT}/../ $ENV{MKLROOT}/../../ - PATH_SUFFIXES windows) - set(MKL_INCLUDE_DIRS ${MKL_HOME_PLATFORM}/mkl/include) - set(MKL_LINK_DIRS - ${MKL_HOME_PLATFORM}/mkl/lib/intel64 - ${MKL_HOME_PLATFORM}/tbb/lib/intel64/vc_mt - ${MKL_HOME_PLATFORM}/compiler/lib/intel64 - ${MKL_HOME_PLATFORM}/mkl/tools/builder/lib) - set(MKL_REDIST_DLL_DIRS - ${MKL_HOME_PLATFORM}/redist/intel64/mkl - ${MKL_HOME_PLATFORM}/redist/intel64/tbb/vc_mt - ${MKL_HOME_PLATFORM}/redist/intel64/compiler) - list(APPEND tensorflow_EXTERNAL_LIBRARIES - mkl_intel_lp64_dll mkl_sequential_dll mkl_core_dll mkl_rt mkl_cdll_intel64) - endif() - if (UNIX) - # Fix me: complete the path on linux - find_path(MKL_HOME_PLATFORM mkl - HINTS ${MKL_HOME} ${MKL_HOME}/../ ${MKL_HOME}/../../ - $ENV{MKLROOT} $ENV{MKLROOT}/../ $ENV{MKLROOT}/../../ - PATH_SUFFIXES linux) - set(MKL_INCLUDE_DIRS ${MKL_HOME_PLATFORM}/mkl/include) - set(MKL_LINK_DIRS) # incompleted - set(MKL_REDIST_SO_DIRS) # incompleted - endif() - include_directories(${MKL_INCLUDE_DIRS}) - link_directories(${MKL_LINK_DIRS}) + include(mkl) + list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkl_STATIC_LIBRARIES}) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkl_copy_shared_to_destination) + include_directories(${mkl_INCLUDE_DIRS}) if (tensorflow_ENABLE_MKLDNN_SUPPORT) include(mkldnn) list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkldnn_STATIC_LIBRARIES}) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn_copy_shared_to_destination) include_directories(${mkldnn_INCLUDE_DIRS}) else (tensorflow_ENABLE_MKLDNN_SUPPORT) add_definitions(-DINTEL_MKL_ML) diff --git a/tensorflow/contrib/cmake/external/boringssl.cmake b/tensorflow/contrib/cmake/external/boringssl.cmake index 3c4bb01e24fd121c9d0fc3594cc25de37af0e8a1..fbb14b2515a656f1dfc0e3f63ac367e9b7738a23 100644 --- a/tensorflow/contrib/cmake/external/boringssl.cmake +++ b/tensorflow/contrib/cmake/external/boringssl.cmake @@ -17,7 +17,7 @@ include (ExternalProject) set(boringssl_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/boringssl/src/boringssl/include) #set(boringssl_EXTRA_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/boringssl/src) set(boringssl_URL https://boringssl.googlesource.com/boringssl) -set(boringssl_TAG ee7aa02) +set(boringssl_TAG 7f8c553d7f4db0a6ce727f2986d41bf8fe8ec4bf) set(boringssl_BUILD ${CMAKE_BINARY_DIR}/boringssl/src/boringssl-build) #set(boringssl_LIBRARIES ${boringssl_BUILD}/obj/so/libboringssl.so) set(boringssl_STATIC_LIBRARIES diff --git a/tensorflow/contrib/cmake/external/mkl.cmake b/tensorflow/contrib/cmake/external/mkl.cmake new file mode 100644 index 0000000000000000000000000000000000000000..a172e3a41a283359b9a8c823ddcb2b1973b5b3cc --- /dev/null +++ b/tensorflow/contrib/cmake/external/mkl.cmake @@ -0,0 +1,68 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +include (ExternalProject) + +# NOTE: Different from mkldnn.cmake, this file is meant to download mkl libraries +set(mkl_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/include) +set(mkl_BIN_DIRS ${CMAKE_CURRENT_BINARY_DIR}/mkl/bin) +set(mkl_WIN mklml_win_2018.0.3.20180406.zip) # match for v0.14 +set(mkl_MAC mklml_mac_2018.0.3.20180406.tgz) +set(mkl_LNX mklml_lnx_2018.0.3.20180406.tgz) +set(mkl_TAG v0.14) +set(mkl_URL https://github.com/intel/mkl-dnn/releases) + +if (WIN32) + set(mkl_DOWNLOAD_URL ${mkl_URL}/download/${mkl_TAG}/${mkl_WIN}) + list(APPEND mkl_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/mklml.lib) + list(APPEND mkl_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libiomp5md.lib) + list(APPEND mkl_SHARED_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/mklml.dll) + list(APPEND mkl_SHARED_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libiomp5md.dll) +elseif (UNIX) + set(mkl_DOWNLOAD_URL ${mkl_URL}/download/${mkl_TAG}/${mkl_LNX}) + list(APPEND mkl_SHARED_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libiomp5.so) + list(APPEND mkl_SHARED_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libmklml_gnu.so) + list(APPEND mkl_SHARED_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libmklml_intel.so) +elseif (APPLE) + set(mkl_DOWNLOAD_URL ${mkl_URL}/download/${mkl_TAG}/${mkl_MAC}) + #TODO need more information +endif () + +ExternalProject_Add(mkl + PREFIX mkl + URL ${mkl_DOWNLOAD_URL} + DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "") + +# put mkl dynamic libraries in one bin directory +add_custom_target(mkl_create_destination_dir + COMMAND ${CMAKE_COMMAND} -E make_directory ${mkl_BIN_DIRS} + DEPENDS mkl) + +add_custom_target(mkl_copy_shared_to_destination DEPENDS mkl_create_destination_dir) + +foreach(dll_file ${mkl_SHARED_LIBRARIES}) + add_custom_command(TARGET mkl_copy_shared_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${dll_file} ${mkl_BIN_DIRS}) +endforeach() diff --git a/tensorflow/contrib/cmake/external/mkldnn.cmake b/tensorflow/contrib/cmake/external/mkldnn.cmake index a639fdee367f060d4c8a79267803da6ffe3dc503..8123ee1f393ab8e3a52f13915ea2a65decc188d9 100644 --- a/tensorflow/contrib/cmake/external/mkldnn.cmake +++ b/tensorflow/contrib/cmake/external/mkldnn.cmake @@ -22,8 +22,11 @@ set(mkldnn_TAG 3063b2e4c943983f6bf5f2fb9a490d4a998cd291) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release/mkldnn.lib) + set(mkldnn_SHARED_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release/mkldnn.dll) + set(mkldnn_BUILD ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release) else() set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/mkldnn.lib) + set(mkldnn_SHARED_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/mkldnn.dll) endif() else() set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/libmkldnn.a) @@ -31,6 +34,7 @@ endif() ExternalProject_Add(mkldnn PREFIX mkldnn + DEPENDS mkl GIT_REPOSITORY ${mkldnn_URL} GIT_TAG ${mkldnn_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" @@ -40,5 +44,11 @@ ExternalProject_Add(mkldnn CMAKE_CACHE_ARGS -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF - -DMKLINC:STRING=${MKL_INCLUDE_DIRS} + -DMKLINC:STRING=${mkl_INCLUDE_DIRS} ) + +# since mkldnn depends on mkl, copy the mkldnn.dll together with mklml.dll to mkl_bin_dirs +add_custom_target(mkldnn_copy_shared_to_destination DEPENDS mkldnn) + +add_custom_command(TARGET mkldnn_copy_shared_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${mkldnn_SHARED_LIBRARIES} ${mkl_BIN_DIRS}) diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index b9d1dd88d4c2d3c9141ba56e14911e06b4d33f7c..eba3bcfc79efe87d0a45c979c5accfa1b6511ed0 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public) set(nsync_URL https://github.com/google/nsync) -set(nsync_TAG 0559ce013feac8db639ee1bf776aca0325d28777) +set(nsync_TAG 1.20.0) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake index ab464bc99a43138130bb2758ae28ecef29805c31..f56fb35a0f71250f00b84e5cf94a24682bda6c82 100644 --- a/tensorflow/contrib/cmake/external/protobuf.cmake +++ b/tensorflow/contrib/cmake/external/protobuf.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src) set(PROTOBUF_URL https://github.com/google/protobuf.git) -set(PROTOBUF_TAG b04e5cba356212e4e8c66c61bbe0c3a20537c5b9) +set(PROTOBUF_TAG v3.6.0) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 015cb73bbd93bb77f6748a364b263d99eb305c27..8ff6ebedab05d4c058ada33c29ce5ea9c5d18a96 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -35,6 +35,7 @@ tensorflow/python/keras tensorflow/python/keras/applications tensorflow/python/keras/datasets tensorflow/python/keras/engine +tensorflow/python/keras/estimator tensorflow/python/keras/layers tensorflow/python/keras/preprocessing tensorflow/python/keras/utils @@ -85,6 +86,8 @@ tensorflow/contrib/batching/python/ops tensorflow/contrib/bayesflow tensorflow/contrib/bayesflow/python tensorflow/contrib/bayesflow/python/ops +# tensorflow/contrib/bigtable/python +# tensorflow/contrib/bigtable/python/ops tensorflow/contrib/boosted_trees tensorflow/contrib/boosted_trees/estimator_batch tensorflow/contrib/boosted_trees/kernels @@ -115,8 +118,6 @@ tensorflow/contrib/coder/python/ops tensorflow/contrib/compiler tensorflow/contrib/constrained_optimization tensorflow/contrib/constrained_optimization/python -tensorflow/contrib/control_flow -tensorflow/contrib/control_flow/python tensorflow/contrib/copy_graph tensorflow/contrib/copy_graph/python tensorflow/contrib/copy_graph/python/util @@ -131,6 +132,7 @@ tensorflow/contrib/data tensorflow/contrib/data/kernels tensorflow/contrib/data/python tensorflow/contrib/data/python/kernel_tests +tensorflow/contrib/data/python/kernel_tests/serialization tensorflow/contrib/data/python/ops tensorflow/contrib/decision_trees tensorflow/contrib/decision_trees/proto diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake index 2e0a2fcef4cbdc50f0521296c4a25a864dbd8b77..7a30eb94f54b18a2a517615a315e23e09e1170d0 100644 --- a/tensorflow/contrib/cmake/tf_c.cmake +++ b/tensorflow/contrib/cmake/tf_c.cmake @@ -36,16 +36,3 @@ add_dependencies( tf_cc_while_loop tf_core_lib tf_protos_cc) - -if(tensorflow_BUILD_PYTHON_BINDINGS) - add_library(tf_c_python_api OBJECT - "${tensorflow_source_dir}/tensorflow/c/python_api.cc" - "${tensorflow_source_dir}/tensorflow/c/python_api.h" - ) - add_dependencies( - tf_c_python_api - tf_c - tf_core_lib - tf_core_framework - tf_protos_cc) -endif() diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index dac84ccb0dbf4848329e35a6e9bcf6213d8c0e55..872b016d2b6c1b8fb5875c9568a1b7b6201507c0 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -49,43 +49,48 @@ function(RELATIVE_PROTOBUF_GENERATE_CPP SRCS HDRS ROOT_DIR) set(${HDRS} ${${HDRS}} PARENT_SCOPE) endfunction() -if(NOT WIN32) - function(RELATIVE_PROTOBUF_GENERATE_GRPC_CPP SRCS HDRS ROOT_DIR) - if(NOT ARGN) - message(SEND_ERROR "Error: RELATIVE_PROTOBUF_GENERATE_GRPC_CPP() called without any proto files") - return() +function(RELATIVE_PROTOBUF_GENERATE_GRPC_CPP SRCS HDRS ROOT_DIR) + if(NOT ARGN) + message(SEND_ERROR "Error: RELATIVE_PROTOBUF_GENERATE_GRPC_CPP() called without any proto files") + return() + endif() + + set(${SRCS}) + set(${HDRS}) + foreach(FIL ${ARGN}) + set(ABS_FIL ${ROOT_DIR}/${FIL}) + get_filename_component(FIL_WE ${FIL} NAME_WE) + get_filename_component(FIL_DIR ${ABS_FIL} PATH) + file(RELATIVE_PATH REL_DIR ${ROOT_DIR} ${FIL_DIR}) + + list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.cc") + list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.h") + list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc") + list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h") + + # We adust the path of the gRPC code generation accordingly. + if(WIN32) + set(GRPC_PROTOC_PLUGIN_PATH ${GRPC_BUILD}/Release/grpc_cpp_plugin.exe) + else() + set(GRPC_PROTOC_PLUGIN_PATH ${GRPC_BUILD}/grpc_cpp_plugin) endif() - set(${SRCS}) - set(${HDRS}) - foreach(FIL ${ARGN}) - set(ABS_FIL ${ROOT_DIR}/${FIL}) - get_filename_component(FIL_WE ${FIL} NAME_WE) - get_filename_component(FIL_DIR ${ABS_FIL} PATH) - file(RELATIVE_PATH REL_DIR ${ROOT_DIR} ${FIL_DIR}) - - list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.cc") - list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.h") - list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc") - list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h") - - add_custom_command( - OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.cc" - "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.h" - "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc" - "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h" - COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} - ARGS --grpc_out ${CMAKE_CURRENT_BINARY_DIR} --cpp_out ${CMAKE_CURRENT_BINARY_DIR} --plugin protoc-gen-grpc=${GRPC_BUILD}/grpc_cpp_plugin -I ${ROOT_DIR} ${ABS_FIL} -I ${PROTOBUF_INCLUDE_DIRS} - DEPENDS ${ABS_FIL} protobuf grpc - COMMENT "Running C++ protocol buffer grpc compiler on ${FIL}" - VERBATIM ) - endforeach() - - set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE) - set(${SRCS} ${${SRCS}} PARENT_SCOPE) - set(${HDRS} ${${HDRS}} PARENT_SCOPE) - endfunction() -endif() + add_custom_command( + OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.cc" + "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.h" + "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc" + "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h" + COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} + ARGS --grpc_out ${CMAKE_CURRENT_BINARY_DIR} --cpp_out ${CMAKE_CURRENT_BINARY_DIR} --plugin=protoc-gen-grpc=${GRPC_PROTOC_PLUGIN_PATH} -I ${ROOT_DIR} ${ABS_FIL} -I ${PROTOBUF_INCLUDE_DIRS} + DEPENDS ${ABS_FIL} protobuf grpc + COMMENT "Running C++ protocol buffer grpc compiler on ${FIL}" + VERBATIM ) + endforeach() + + set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE) + set(${SRCS} ${${SRCS}} PARENT_SCOPE) + set(${HDRS} ${${HDRS}} PARENT_SCOPE) +endfunction() function(RELATIVE_PROTOBUF_TEXT_GENERATE_CPP SRCS HDRS ROOT_DIR) if(NOT ARGN) @@ -125,6 +130,7 @@ endfunction() file(GLOB_RECURSE tf_protos_cc_srcs RELATIVE ${tensorflow_source_dir} "${tensorflow_source_dir}/tensorflow/core/*.proto" + "${tensorflow_source_dir}/tensorflow/compiler/xla/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/proto/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/tpu/proto/*.proto" ) @@ -174,17 +180,14 @@ RELATIVE_PROTOBUF_TEXT_GENERATE_CPP(PROTO_TEXT_SRCS PROTO_TEXT_HDRS ${tensorflow_source_dir} ${tf_proto_text_srcs} ) -if(WIN32) - add_library(tf_protos_cc ${PROTO_SRCS} ${PROTO_HDRS}) -else() - file(GLOB_RECURSE tf_protos_grpc_cc_srcs RELATIVE ${tensorflow_source_dir} - "${tensorflow_source_dir}/tensorflow/core/debug/*.proto" - ) - RELATIVE_PROTOBUF_GENERATE_GRPC_CPP(PROTO_GRPC_SRCS PROTO_GRPC_HDRS - ${tensorflow_source_dir} ${tf_protos_grpc_cc_srcs} - ) - add_library(tf_protos_cc ${PROTO_GRPC_SRCS} ${PROTO_GRPC_HDRS} ${PROTO_SRCS} ${PROTO_HDRS}) -endif() +file(GLOB_RECURSE tf_protos_grpc_cc_srcs RELATIVE ${tensorflow_source_dir} + "${tensorflow_source_dir}/tensorflow/core/debug/*.proto" + "${tensorflow_source_dir}/tensorflow/core/protobuf/master_service.proto" +) +RELATIVE_PROTOBUF_GENERATE_GRPC_CPP(PROTO_GRPC_SRCS PROTO_GRPC_HDRS + ${tensorflow_source_dir} ${tf_protos_grpc_cc_srcs} +) +add_library(tf_protos_cc ${PROTO_GRPC_SRCS} ${PROTO_GRPC_HDRS} ${PROTO_SRCS} ${PROTO_HDRS}) ######################################################## # tf_core_lib library @@ -233,15 +236,6 @@ if(WIN32) list(APPEND tf_core_lib_srcs ${tf_core_platform_windows_srcs}) endif(WIN32) -if(tensorflow_ENABLE_SSL_SUPPORT) - # Cloud libraries require boringssl. - file(GLOB tf_core_platform_cloud_srcs - "${tensorflow_source_dir}/tensorflow/core/platform/cloud/*.h" - "${tensorflow_source_dir}/tensorflow/core/platform/cloud/*.cc" - ) - list(APPEND tf_core_lib_srcs ${tf_core_platform_cloud_srcs}) -endif() - if (tensorflow_ENABLE_HDFS_SUPPORT) list(APPEND tf_core_platform_hdfs_srcs "${tensorflow_source_dir}/tensorflow/core/platform/hadoop/hadoop_file_system.cc" diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 2d76bf530a2100b2afa80a16a5d64b6ec51ffc68..844f62649d970506f1b4b4c5718fab8d1f0856e1 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -134,14 +134,13 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) list(APPEND tf_core_kernels_srcs ${tf_contrib_kernels_srcs}) endif(tensorflow_BUILD_CONTRIB_KERNELS) -if(NOT tensorflow_ENABLE_SSL_SUPPORT) - # Cloud libraries require boringssl. - file(GLOB tf_core_kernels_cloud_srcs - "${tensorflow_source_dir}/tensorflow/contrib/cloud/kernels/*.h" - "${tensorflow_source_dir}/tensorflow/contrib/cloud/kernels/*.cc" - ) +# Cloud libraries require curl and boringssl. +# Curl is not supported yet anyway so we remove for now. +file(GLOB tf_core_kernels_cloud_srcs + "${tensorflow_source_dir}/tensorflow/contrib/cloud/kernels/*.h" + "${tensorflow_source_dir}/tensorflow/contrib/cloud/kernels/*.cc" +) list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_cloud_srcs}) -endif() file(GLOB_RECURSE tf_core_kernels_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/kernels/*test*.h" diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 92446044892127284ecb8753a250b77cb2a5743a..e3b59001bcb4f081eb2db3443ee9ad714c822ac8 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -456,6 +456,18 @@ add_custom_command( COMMENT "Running SWIG to generate Python wrappers" VERBATIM ) +add_library(tf_c_python_api OBJECT + "${tensorflow_source_dir}/tensorflow/c/python_api.cc" + "${tensorflow_source_dir}/tensorflow/c/python_api.h" +) +add_dependencies( + tf_c_python_api + tf_c + tf_core_lib + tf_core_framework + tf_protos_cc + tf_python_protos_cc) + set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.h" "${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.cc" @@ -743,26 +755,65 @@ set(api_init_list_file "${tensorflow_source_dir}/api_init_files_list.txt") file(WRITE "${api_init_list_file}" "${api_init_files}") # Run create_python_api.py to generate __init__.py files. -add_custom_command( - OUTPUT ${api_init_files} - DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops - # tensorflow/__init__.py depends on files generated in this step. So, remove it while - # this step is running since the files aren't there yet. - COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py - - # Run create_python_api.py to generate API init files. - COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} - "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" - "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" - "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" - "--package=tensorflow.python" - "--apiname=tensorflow" - "${api_init_list_file}" - - COMMENT "Generating __init__.py files for Python API." - WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" -) +### TODO +# In order to download and compile MKL/MKL-DNN automatically in cmake script, mkl-built libraries should be added to system path +# to be loaded by python executor. However `add_custom_command` has an issue with `COMMAND ${CMAKE_COMMAND} -E env PATH=`, where +# arguments of multiple paths (such as D:/;D:/mkl) will be parsed in to seperate string without semicolon and that command fail to +# recongnize paths. As CUDA isn't built with MKL, the MKL built directory is the only path to this command to work around that issue. +# To not override the CUDA and system path in other circumstances, `if-else` branch used here to handle this problem, +# and should be removed if the path issue can be resolved. +### + +if (tensorflow_ENABLE_MKL_SUPPORT) + # add mkl dist dlls to system path for python + # TODO: In current cmake version, PY_RUNTIME_ENV behaves strange with multiple paths, + # so we have to specify only one path in it to work around the issue. We need this if/else + # to protect overwriting CUDA environments + set(PY_RUNTIME_ENV ${mkl_BIN_DIRS}) + add_custom_command( + OUTPUT ${api_init_files} + DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops + + # tensorflow/__init__.py depends on files generated in this step. So, remove it while + # this step is running since the files aren't there yet. + COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + + # Run create_python_api.py to generate API init files. + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python PATH=${PY_RUNTIME_ENV} ${PYTHON_EXECUTABLE} + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" + "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" + "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" + "--package=tensorflow.python" + "--apiname=tensorflow" + "${api_init_list_file}" + + COMMENT "Generating __init__.py files for Python API." + WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" + VERBATIM + ) +else (tensorflow_ENABLE_MKL_SUPPORT) + add_custom_command( + OUTPUT ${api_init_files} + DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops + + # tensorflow/__init__.py depends on files generated in this step. So, remove it while + # this step is running since the files aren't there yet. + COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + + # Run create_python_api.py to generate API init files. + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" + "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" + "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" + "--package=tensorflow.python" + "--apiname=tensorflow" + "${api_init_list_file}" + + COMMENT "Generating __init__.py files for Python API." + WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" + ) +endif (tensorflow_ENABLE_MKL_SUPPORT) add_custom_target(tf_python_api SOURCES ${api_init_files}) add_dependencies(tf_python_api tf_python_ops) diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index 38f40452b533fdc0dba6ac686a0ff43a2ef13cb8..fdf522f1fd90ffc64acbe82381ef57a389645d61 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -145,3 +145,8 @@ install(DIRECTORY ${tensorflow_source_dir}/third_party/eigen3/ # unsupported Eigen directory install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen/ DESTINATION include/unsupported/Eigen) +# mkl +if (tensorflow_ENABLE_MKL_SUPPORT) + install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/include/ + DESTINATION include/mkl) +endif (tensorflow_ENABLE_MKL_SUPPORT) diff --git a/tensorflow/contrib/cmake/tf_stream_executor.cmake b/tensorflow/contrib/cmake/tf_stream_executor.cmake index 9a37b681194d4ef82b27a0160dd969f733ecad67..6d634cb1709910f366c7ca538d28bd802b2a7c63 100644 --- a/tensorflow/contrib/cmake/tf_stream_executor.cmake +++ b/tensorflow/contrib/cmake/tf_stream_executor.cmake @@ -64,8 +64,6 @@ file(GLOB tf_stream_executor_srcs if (tensorflow_ENABLE_GPU) file(GLOB tf_stream_executor_gpu_srcs "${tensorflow_source_dir}/tensorflow/stream_executor/cuda/*.cc" - "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.h" - "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.cc" ) if (NOT tensorflow_BUILD_CC_TESTS) file(GLOB tf_stream_executor_gpu_tests @@ -76,11 +74,11 @@ if (tensorflow_ENABLE_GPU) list(APPEND tf_stream_executor_srcs ${tf_stream_executor_gpu_srcs}) endif() -#file(GLOB_RECURSE tf_stream_executor_test_srcs -# "${tensorflow_source_dir}/tensorflow/stream_executor/*_test.cc" -# "${tensorflow_source_dir}/tensorflow/stream_executor/*_test.h" -#) -#list(REMOVE_ITEM tf_stream_executor_srcs ${tf_stream_executor_test_srcs}) +file(GLOB_RECURSE tf_stream_executor_test_srcs + "${tensorflow_source_dir}/tensorflow/stream_executor/*test.cc" + "${tensorflow_source_dir}/tensorflow/stream_executor/lib/*test.h" +) +list(REMOVE_ITEM tf_stream_executor_srcs ${tf_stream_executor_test_srcs}) if (NOT WIN32) set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lgomp") diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index c8de8db126f7724386be565aa524b4b527976730..eb9482dc25f2be8ce46cc38bf3dd28889b09a9d4 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -325,8 +325,6 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py" # b/71901810 # Broken io_utils_test "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/utils/io_utils_test.py" # b/72894325 - # OOM - "${tensorflow_source_dir}/tensorflow/python/training/saver_large_variable_test.py" # b/110210559 ) endif() list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude}) diff --git a/tensorflow/contrib/coder/python/layers/entropybottleneck.py b/tensorflow/contrib/coder/python/layers/entropybottleneck.py index 0fbe3081af0b4de7f116918b3f49efe91a2d83bd..0c997bd4fdfa4233117c9fec2c4397301b1c8cb9 100644 --- a/tensorflow/contrib/coder/python/layers/entropybottleneck.py +++ b/tensorflow/contrib/coder/python/layers/entropybottleneck.py @@ -28,7 +28,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras import engine +from tensorflow.python.keras.engine import base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import init_ops @@ -40,7 +40,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.summary import summary -class EntropyBottleneck(engine.Layer): +class EntropyBottleneck(base_layer.Layer): """Entropy bottleneck layer. This layer can be used to model the entropy (the amount of information @@ -262,7 +262,7 @@ class EntropyBottleneck(engine.Layer): self._range_coder_precision = int(range_coder_precision) self._data_format = data_format self._channel_axis(2) # trigger ValueError early - self.input_spec = engine.InputSpec(min_ndim=2) + self.input_spec = base_layer.InputSpec(min_ndim=2) @property def init_scale(self): @@ -357,7 +357,7 @@ class EntropyBottleneck(engine.Layer): channels = input_shape[channel_axis].value if channels is None: raise ValueError("The channel dimension of the inputs must be defined.") - self.input_spec = engine.InputSpec( + self.input_spec = base_layer.InputSpec( ndim=input_shape.ndims, axes={channel_axis: channels}) filters = (1,) + self.filters + (1,) scale = self.init_scale ** (1 / (len(self.filters) + 1)) diff --git a/tensorflow/contrib/control_flow/BUILD b/tensorflow/contrib/control_flow/BUILD deleted file mode 100644 index e8036d63aeeac224b226899c036416a06b4ffe65..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/control_flow/BUILD +++ /dev/null @@ -1,53 +0,0 @@ -# New implementations of control flow ops - -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//visibility:public"]) - -load("//tensorflow:tensorflow.bzl", "tf_py_test") - -py_library( - name = "control_flow", - srcs = ["__init__.py"], - srcs_version = "PY2AND3", - deps = [ - ":cond_v2", - ], -) - -py_library( - name = "cond_v2", - srcs = ["python/cond_v2.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:c_api_util", - "//tensorflow/python:framework_ops", - "//tensorflow/python:function", - "//tensorflow/python:function_def_to_graph", - "//tensorflow/python:functional_ops_gen", - "//tensorflow/python:gradients", - "//tensorflow/python:pywrap_tensorflow", - "//tensorflow/python:util", - ], -) - -tf_py_test( - name = "cond_v2_test", - size = "small", - srcs = ["python/cond_v2_test.py"], - additional_deps = [ - ":cond_v2", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework", - "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:training", - ], - grpc_enabled = True, -) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index 8285ea04926d3a24e9c22bd6d69eb7a48f5e3a85..252ea1560d7f5be3799686d6d91ae9a6d262ac0a 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -768,7 +768,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLSTMCheckpointableSingleLayer(self): num_units = 2 direction = CUDNN_RNN_UNIDIRECTION @@ -781,7 +781,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGRUCheckpointableSingleLayer(self): num_units = 2 direction = CUDNN_RNN_UNIDIRECTION @@ -826,7 +826,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCudnnCompatibleLSTMCheckpointablMultiLayer(self): num_units = 2 num_layers = 3 diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 8822a7523f6b168f569e29970c9c29f2eb3614fc..748d7cd011f32fdebd781176b560b9b7498f327e 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -33,7 +33,7 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import base as checkpointable_lib +from tensorflow.python.training.checkpointable import tracking as checkpointable_lib CUDNN_RNN_UNIDIRECTION = "unidirectional" CUDNN_RNN_BIDIRECTION = "bidirectional" diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 1af1ed08b53ee04367eb316d5c9caa0216f2e88d..156538b4e01bf1a1ccca0fca1e309b1d37b6dbc0 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -20,12 +20,15 @@ be used in conjunction with the @{tf.data.Dataset} API. Note that the guarantees as `tf.data`, but we will provide deprecation advice in advance of removing existing functionality. -See the @{$datasets$Importing Data} Programmer's Guide for an overview. +See @{$guide/datasets$Importing Data} for an overview. @@Counter @@CheckpointInputPipelineHook @@CsvDataset +@@RandomDataset +@@Reducer @@SqlDataset +@@TFRecordWriter @@assert_element_shape @@batch_and_drop_remainder @@ -33,11 +36,15 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@choose_from_datasets @@dense_to_sparse_batch @@enumerate_dataset + +@@get_single_element +@@group_by_reducer @@group_by_window @@ignore_errors @@make_batched_features_dataset @@make_csv_dataset @@make_saveable_from_iterator + @@map_and_batch @@padded_batch_and_drop_remainder @@parallel_interleave @@ -50,8 +57,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@sliding_window_batch @@sloppy_interleave @@unbatch - -@@get_single_element +@@unique """ from __future__ import absolute_import @@ -71,13 +77,17 @@ from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset from tensorflow.contrib.data.python.ops.error_ops import ignore_errors from tensorflow.contrib.data.python.ops.get_single_element import get_single_element from tensorflow.contrib.data.python.ops.grouping import bucket_by_sequence_length +from tensorflow.contrib.data.python.ops.grouping import group_by_reducer from tensorflow.contrib.data.python.ops.grouping import group_by_window +from tensorflow.contrib.data.python.ops.grouping import Reducer +from tensorflow.contrib.data.python.ops.interleave_ops import choose_from_datasets from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datasets from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device +from tensorflow.contrib.data.python.ops.random_ops import RandomDataset from tensorflow.contrib.data.python.ops.readers import CsvDataset from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset from tensorflow.contrib.data.python.ops.readers import make_csv_dataset @@ -87,6 +97,8 @@ from tensorflow.contrib.data.python.ops.resampling import rejection_resample from tensorflow.contrib.data.python.ops.scan_ops import scan from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch +from tensorflow.contrib.data.python.ops.unique import unique +from tensorflow.contrib.data.python.ops.writers import TFRecordWriter # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc index a2bfce03620a1482f5b21cbf23c66833bc5cd480..b3d464d7165d53cf198072e06214f7d5e982073d 100644 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -40,7 +40,8 @@ class FunctionBufferingResource : public ResourceBase { const NameAttrList& func, int64 buffer_size, const string& source_device, const string& target_device, - const std::vector& func_args) + const std::vector& func_args, + const DataTypeVector& output_types) : lib_(lib), pflr_(std::move(pflr)), func_(func), @@ -48,6 +49,7 @@ class FunctionBufferingResource : public ResourceBase { source_device_(source_device), target_device_(target_device), func_args_(func_args), + output_types_(output_types), handle_(kInvalidHandle), is_buffering_(false), end_of_sequence_(false), @@ -176,6 +178,13 @@ class FunctionBufferingResource : public ResourceBase { AllocatorAttributes arg_alloc_attr; arg_alloc_attr.set_on_host(true); opts.args_alloc_attrs.push_back(arg_alloc_attr); + for (const auto& dtype : output_types_) { + AllocatorAttributes ret_alloc_attrs; + if (DataTypeAlwaysOnHost(dtype)) { + ret_alloc_attrs.set_on_host(true); + } + opts.rets_alloc_attrs.push_back(ret_alloc_attrs); + } if (opts.source_device != target_device_) { opts.remote_execution = true; } @@ -233,6 +242,7 @@ class FunctionBufferingResource : public ResourceBase { const string source_device_; const string target_device_; const std::vector func_args_; + const DataTypeVector output_types_; FunctionLibraryRuntime::Handle handle_ GUARDED_BY(mu_); std::deque buffer_ GUARDED_BY(mu_); std::deque requests_ GUARDED_BY(mu_); @@ -250,6 +260,7 @@ class FunctionBufferResourceHandleOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &buffer_size_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); } ~FunctionBufferResourceHandleOp() override { @@ -269,18 +280,20 @@ class FunctionBufferResourceHandleOp : public OpKernel { std::vector func_args; func_args.push_back(*string_arg); + const string& source_device = ctx->device()->name(); + // Obtain and canonicalize target_device. const Tensor* target_arg; OP_REQUIRES_OK(ctx, ctx->input("target_device", &target_arg)); - const string& target_device = - DeviceNameUtils::CanonicalizeDeviceName(target_arg->scalar()()); + string target_device; + OP_REQUIRES_OK(ctx, DeviceNameUtils::CanonicalizeDeviceName( + target_arg->scalar()(), source_device, + &target_device)); FunctionLibraryRuntime* lib = ctx->function_library(); OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library is provided.")); - const string& source_device = ctx->device()->name(); - mutex_lock l(mu_); if (!initialized_) { OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def())); @@ -297,7 +310,7 @@ class FunctionBufferResourceHandleOp : public OpKernel { this](FunctionBufferingResource** ptr) { *ptr = new FunctionBufferingResource( clone_lib, std::move(pflr), func_, buffer_size_, - source_device, target_device, func_args); + source_device, target_device, func_args, output_types_); return Status::OK(); })); core::ScopedUnref s(buffer); @@ -319,6 +332,7 @@ class FunctionBufferResourceHandleOp : public OpKernel { int64 buffer_size_; string container_; string name_; + DataTypeVector output_types_; }; REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource") diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc index 3dfc3741c2b040dd5be3223c24f0715ba3be4248..141706f393b076d9f55898ca4bdbe7438f7c3625 100644 --- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/util/work_sharder.h" namespace tensorflow { namespace { @@ -24,19 +25,32 @@ namespace { class ThreadPoolResource : public ResourceBase { public: ThreadPoolResource(Env* env, const ThreadOptions& thread_options, - const string& name, int num_threads, bool low_latency_hint) - : thread_pool_(env, thread_options, name, num_threads, low_latency_hint) { - } + const string& name, int num_threads, bool low_latency_hint, + int max_intra_op_parallelism) + : thread_pool_(env, thread_options, name, num_threads, low_latency_hint), + max_intra_op_parallelism_(max_intra_op_parallelism) {} // Schedules fn() for execution in the pool of threads. void Schedule(std::function fn) { - thread_pool_.Schedule(std::move(fn)); + if (max_intra_op_parallelism_ < 0) { + thread_pool_.Schedule(std::move(fn)); + } else { + thread_pool_.Schedule(std::bind( + [this](std::function bound_fn) { + // TODO(mrry): Consider moving this thread-local configuration to + // the threads themselves. + ScopedPerThreadMaxParallelism scope(max_intra_op_parallelism_); + bound_fn(); + }, + std::move(fn))); + } } string DebugString() override { return "ThreadPoolResource"; } private: thread::ThreadPool thread_pool_; + const int max_intra_op_parallelism_; }; // Creates a handle to a ThreadPool resource. Note that we don't use @@ -48,6 +62,8 @@ class ThreadPoolHandleOp : public OpKernel { explicit ThreadPoolHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("display_name", &display_name_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("num_threads", &num_threads_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_intra_op_parallelism", + &max_intra_op_parallelism_)); OP_REQUIRES( ctx, num_threads_ > 0, errors::InvalidArgument("`num_threads` must be greater than zero.")); @@ -78,7 +94,7 @@ class ThreadPoolHandleOp : public OpKernel { EXCLUSIVE_LOCKS_REQUIRED(mu_) { *ret = new ThreadPoolResource( ctx->env(), {}, display_name_, - num_threads_, + num_threads_, max_intra_op_parallelism_, false /* low_latency_hint */); return Status::OK(); })); @@ -95,6 +111,7 @@ class ThreadPoolHandleOp : public OpKernel { bool initialized_ GUARDED_BY(mu_) = false; string display_name_; int num_threads_; + int max_intra_op_parallelism_; }; class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc index f271d269ab1b9339de4657e459dcbbd462890f0a..8413fcaf872f49f654c6a1327a14d5c44bdd815a 100644 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -104,6 +104,7 @@ REGISTER_OP("FunctionBufferingResource") .Attr("container: string") .Attr("f: func") .Attr("buffer_size: int") + .Attr("output_types: list(type)") .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( Creates a resource that fills up a buffer by making function calls. @@ -117,6 +118,7 @@ container: If non-empty, this resource is placed in the given container. Otherwise, a default container is used. shared_name: If non-empty, this resource will be shared under the given name across multiple sessions. +output_types: The type list for the return values. )doc"); REGISTER_OP("FunctionBufferingResourceGetNext") @@ -158,6 +160,7 @@ REGISTER_OP("ThreadPoolHandle") .Output("handle: resource") .SetShapeFn(shape_inference::ScalarShape) .Attr("num_threads: int") + .Attr("max_intra_op_parallelism: int = 1") .Attr("display_name: string") .Attr("container: string = ''") .Attr("shared_name: string = ''") @@ -166,6 +169,8 @@ Creates a custom thread pool with the given number of threads. handle: A resource that can be consumed by one or more ThreadPoolDataset ops. num_threads: The number of threads in the thread pool. +max_intra_op_parallelism: The maximum degree of parallelism to use within + operations that execute on this threadpool. display_name: A human-readable name for the threads that may be visible in some visualizations. )doc"); diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 4e3f9801d7144695478d7fcf2fbc9ecf6e57117a..d81654e039c53e5b9434288352ef1b2416a4b7e8 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -4,7 +4,7 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test") py_test( name = "batch_dataset_op_test", @@ -16,19 +16,21 @@ py_test( "no_pip", ], deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:script_ops", + "//tensorflow/python:session", "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], @@ -40,7 +42,6 @@ py_test( srcs = ["bucketing_test.py"], srcs_version = "PY2AND3", deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:grouping", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -49,37 +50,33 @@ py_test( "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", - "//third_party/py/numpy", - ], -) - -py_test( - name = "cache_dataset_op_test", - size = "small", - srcs = ["cache_dataset_op_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":dataset_serialization_test", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", ], ) py_test( - name = "concatenate_dataset_op_test", + name = "csv_dataset_op_test", size = "small", - srcs = ["concatenate_dataset_op_test.py"], + srcs = ["csv_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ - ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:error_ops", + "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:tensor_shape", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//tensorflow/python:session", + "//tensorflow/python/data/ops:readers", "//third_party/py/numpy", ], ) @@ -94,104 +91,44 @@ py_test( "nomac", # b/62040583 ], deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:batching", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:session", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", - "//third_party/py/numpy", ], ) -py_library( - name = "dataset_serialization_test", - srcs = [ - "dataset_serialization_test_base.py", - ], +py_test( + name = "directed_interleave_dataset_test", + size = "medium", + srcs = ["directed_interleave_dataset_test.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/contrib/data/python/ops:interleave_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:platform", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variables", - "//tensorflow/python/data/ops:iterator_ops", - "//third_party/py/numpy", - ], -) - -py_test( - name = "csv_dataset_op_test", - size = "small", - srcs = ["csv_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:error_ops", - "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/python:random_seed", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) py_test( - name = "filter_dataset_op_test", + name = "get_single_element_test", size = "small", - srcs = ["filter_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - "optonly", - ], + srcs = ["get_single_element_test.py"], deps = [ - ":dataset_serialization_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:functional_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], -) - -tf_py_test( - name = "flat_map_dataset_op_test", - size = "medium", - srcs = ["flat_map_dataset_op_test.py"], - additional_deps = [ - ":dataset_serialization_test", - "//third_party/py/numpy", - "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:get_single_element", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:function", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:session", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", ], - grpc_enabled = True, - tags = ["no_pip"], ) py_test( @@ -206,10 +143,8 @@ py_test( "notap", ], deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:interleave_ops", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", @@ -217,43 +152,28 @@ py_test( "//tensorflow/python:script_ops", "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", + "@six_archive//:six", ], ) py_test( - name = "directed_interleave_dataset_test", - size = "medium", - srcs = ["directed_interleave_dataset_test.py"], + name = "iterator_ops_test", + size = "small", + srcs = ["iterator_ops_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ - ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:interleave_ops", - "//tensorflow/python:client", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:training", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], -) - -tf_py_test( - name = "get_single_element_test", - size = "small", - srcs = ["get_single_element_test.py"], - additional_deps = [ - "//third_party/py/numpy", - "//tensorflow/contrib/data/python/ops:get_single_element", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python:array_ops", + "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_test_lib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:model_fn", ], ) @@ -268,27 +188,13 @@ py_test( "optonly", ], deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:error_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", - "//tensorflow/python:function", - "//tensorflow/python:functional_ops", "//tensorflow/python:io_ops", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:script_ops", - "//tensorflow/python:sparse_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:string_ops", "//tensorflow/python:util", - "//tensorflow/python:variable_scope", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -300,23 +206,30 @@ py_test( srcs = ["optimize_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/python:platform", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", ], ) -py_test( - name = "prefetch_dataset_op_test", +cuda_py_test( + name = "prefetching_ops_test", size = "small", - srcs = ["prefetch_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":dataset_serialization_test", - "//tensorflow/python:platform", + srcs = ["prefetching_ops_test.py"], + additional_deps = [ + "//tensorflow/contrib/data/python/ops:prefetching_ops", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:function", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", ], ) @@ -326,20 +239,13 @@ py_test( srcs = ["range_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:counter", "//tensorflow/contrib/data/python/ops:enumerate_ops", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", - "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:io_ops", - "//tensorflow/python:parsing_ops", "//tensorflow/python:tensor_shape", - "//tensorflow/python:variables", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -351,15 +257,21 @@ py_library( "reader_dataset_ops_test_base.py", ], srcs_version = "PY2AND3", - visibility = ["//visibility:private"], + visibility = [ + "//tensorflow/contrib/data/python/kernel_tests:__pkg__", + "//tensorflow/contrib/data/python/kernel_tests/serialization:__pkg__", + ], deps = [ "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:lib", "//tensorflow/python:parsing_ops", "//tensorflow/python:util", + "//tensorflow/python/data/ops:iterator_ops", "//tensorflow/python/data/ops:readers", ], ) @@ -368,24 +280,18 @@ py_test( name = "reader_dataset_ops_test", size = "medium", srcs = ["reader_dataset_ops_test.py"], - shard_count = 4, srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ - ":dataset_serialization_test", ":reader_dataset_ops_test_base", "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", - "//tensorflow/python:lib", "//tensorflow/python:parsing_ops", "//tensorflow/python:string_ops", - "//tensorflow/python:util", - "//tensorflow/python/data/ops:iterator_ops", "//tensorflow/python/data/ops:readers", "//third_party/py/numpy", ], @@ -413,6 +319,7 @@ py_test( "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", + "@six_archive//:six", ], ) @@ -423,13 +330,14 @@ py_test( srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:scan_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/eager:context", "//third_party/py/numpy", @@ -437,57 +345,55 @@ py_test( ) py_test( - name = "sequence_dataset_op_test", + name = "shuffle_dataset_op_test", size = "medium", - srcs = ["sequence_dataset_op_test.py"], + srcs = ["shuffle_dataset_op_test.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], + tags = [ + "no_pip", + "optonly", + ], deps = [ - ":dataset_serialization_test", - "//tensorflow/python:array_ops", + "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) py_test( - name = "serialization_integration_test", + name = "slide_dataset_op_test", size = "small", - srcs = ["serialization_integration_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], + srcs = ["slide_dataset_op_test.py"], deps = [ - "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/contrib/data/python/ops:sliding", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", ], ) -py_test( - name = "shuffle_dataset_op_test", - size = "medium", - srcs = ["shuffle_dataset_op_test.py"], +py_library( + name = "sql_dataset_op_test_base", + srcs = ["sql_dataset_op_test_base.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], + visibility = [ + "//tensorflow/contrib/data/python/kernel_tests:__pkg__", + "//tensorflow/contrib/data/python/kernel_tests/serialization:__pkg__", + ], deps = [ - ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:iterator_ops", - "//tensorflow/contrib/data/python/ops:shuffle_ops", + "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - "//third_party/py/numpy", + "@org_sqlite//:python", ], ) @@ -496,14 +402,12 @@ py_test( size = "small", srcs = ["sql_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ - ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/python:array_ops", + ":sql_dataset_op_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "@org_sqlite//:python", ], ) @@ -514,7 +418,6 @@ py_test( srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ - ":dataset_serialization_test", ":reader_dataset_ops_test_base", "//tensorflow/contrib/data/python/ops:stats_ops", "//tensorflow/core:protos_all_py", @@ -537,8 +440,12 @@ py_test( "//tensorflow/contrib/data/python/ops:threadpool", "//tensorflow/contrib/data/python/ops:unique", "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:script_ops", "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) @@ -549,87 +456,27 @@ py_test( srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:unique", - "//tensorflow/contrib/stateless", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", ], ) py_test( - name = "zip_dataset_op_test", - size = "small", - srcs = ["zip_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":dataset_serialization_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], -) - -cuda_py_test( - name = "prefetching_ops_test", - size = "small", - srcs = ["prefetching_ops_test.py"], - additional_deps = [ - "//tensorflow/contrib/data/python/ops:prefetching_ops", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:function", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - ], -) - -tf_py_test( - name = "slide_dataset_op_test", - size = "small", - srcs = ["slide_dataset_op_test.py"], - additional_deps = [ - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:sliding", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:math_ops", - "//tensorflow/python:sparse_tensor", - "//third_party/py/numpy", - ], -) - -tf_py_test( name = "writer_ops_test", size = "small", srcs = ["writer_ops_test.py"], - additional_deps = [ + deps = [ "//tensorflow/contrib/data/python/ops:writers", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:io_ops", "//tensorflow/python:lib", - "//tensorflow/python:tensor_shape", "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:readers", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 1435503beb96104c0a845bb064165099c680613a..af97fbf87aee5f7005f9d266ba9b1b6cf109a2ec 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -23,7 +23,6 @@ import time from absl.testing import parameterized import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import batching from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops @@ -642,173 +641,79 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): "number of elements does not match"): sess.run(get_next) + def testMapAndBatchImplicitDispose(self): + # Tests whether a map and batch dataset will be cleaned up correctly when + # the pipeline does not run it until exhaustion. + # The pipeline is TensorSliceDataset -> RepeatDataset(1000) -> + # MapAndBatchDataset(f=square_3, batch_size=100). + components = (np.arange(1000), + np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis], + np.array(37.0) * np.arange(1000)) -class BatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) - def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): - components = ( - np.arange(tensor_slice_len), - np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis], - np.array(multiplier) * np.arange(tensor_slice_len)) + dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat( + 1000).apply(batching.map_and_batch(_map_fn, batch_size=100)) + dataset = dataset.prefetch(5) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() - return dataset_ops.Dataset.from_tensor_slices(components).batch(batch_size) + with self.test_session() as sess: + for _ in range(3): + sess.run(get_next) - def testCore(self): - tensor_slice_len = 8 - batch_size = 2 - num_outputs = tensor_slice_len // batch_size - self.run_core_tests( - lambda: self.build_dataset(15.0, tensor_slice_len, batch_size), - lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), - num_outputs) + @parameterized.parameters(0, 5, 10, 90, 95, 99) + def testMapAndBatchOutOfRangeError(self, threshold): - def _build_dataset_dense_to_sparse(self, components): - return dataset_ops.Dataset.from_tensor_slices(components).map( - lambda x: array_ops.fill([x], x)).apply( - batching.dense_to_sparse_batch(4, [12])) + def raising_py_fn(i): + if i >= threshold: + raise StopIteration() + else: + return i - def testDenseToSparseBatchDatasetCore(self): - components = np.random.randint(5, size=(40,)).astype(np.int32) - diff_comp = np.random.randint(2, size=(100,)).astype(np.int32) - - num_outputs = len(components) // 4 - self.run_core_tests(lambda: self._build_dataset_dense_to_sparse(components), - lambda: self._build_dataset_dense_to_sparse(diff_comp), - num_outputs) - - def _sparse(self, i): - return sparse_tensor.SparseTensorValue( - indices=[[0]], values=(i * [1]), dense_shape=[1]) + iterator = ( + dataset_ops.Dataset.range(100).apply( + batching.map_and_batch( + lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64), + batch_size=10)).make_one_shot_iterator()) + get_next = iterator.get_next() + + with self.test_session() as sess: + for i in range(threshold // 10): + self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next)) + if threshold % 10 != 0: + self.assertAllEqual( + [threshold // 10 * 10 + j for j in range(threshold % 10)], + sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) - def _build_dataset_sparse(self, batch_size=5): - return dataset_ops.Dataset.range(10).map(self._sparse).batch(batch_size) - - def testSparseCore(self): - self.run_core_tests(self._build_dataset_sparse, - lambda: self._build_dataset_sparse(2), 2) - - def _build_dataset_nested_sparse(self): - return dataset_ops.Dataset.range(10).map(self._sparse).batch(5).batch(2) - - def testNestedSparseCore(self): - self.run_core_tests(self._build_dataset_nested_sparse, None, 1) + @parameterized.parameters( + (False, dtypes.bool), + (-42, dtypes.int8), + (-42, dtypes.int16), + (-42, dtypes.int32), + (-42, dtypes.int64), + (42, dtypes.uint8), + (42, dtypes.uint16), + (42.0, dtypes.float16), + (42.0, dtypes.float32), + (42.0, dtypes.float64), + (b"hello", dtypes.string), + ) + def testMapAndBatchTypes(self, element, dtype): + def gen(): + yield element + dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply( + batching.map_and_batch(lambda x: x, batch_size=10)) -class UnbatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): + get_next = dataset.make_one_shot_iterator().get_next() - def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): - components = ( - np.arange(tensor_slice_len), - np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis], - np.array(multiplier) * np.arange(tensor_slice_len)) - - return dataset_ops.Dataset.from_tensor_slices(components).batch( - batch_size).apply(batching.unbatch()) - - def testCore(self): - tensor_slice_len = 8 - batch_size = 2 - num_outputs = tensor_slice_len - self.run_core_tests( - lambda: self.build_dataset(15.0, tensor_slice_len, batch_size), - lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), - num_outputs) - - -class MapAndBatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def testNumParallelBatches(self): - range_size = 11 - num_repeats = 2 - batch_size = 5 - total_outputs = range_size * num_repeats - num_outputs_drop_remainder = total_outputs // batch_size - num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) - num_parallel_batches = 2 - - def build_ds(range_start, drop_remainder=False): - - def _map_fn(x): - return math_ops.square(x) - - return dataset_ops.Dataset.range( - range_start, range_start + range_size).repeat(num_repeats).apply( - batching.map_and_batch( - map_func=_map_fn, - batch_size=batch_size, - num_parallel_batches=num_parallel_batches, - drop_remainder=drop_remainder)) - - self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), - num_outputs_keep_remainder) - self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), - num_outputs_drop_remainder) - - def testNumParallelCalls(self): - range_size = 11 - num_repeats = 2 - batch_size = 5 - total_outputs = range_size * num_repeats - num_outputs_drop_remainder = total_outputs // batch_size - num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) - num_parallel_calls = 7 - - def build_ds(range_start, drop_remainder=False): - - def _map_fn(x): - return math_ops.square(x) - - return dataset_ops.Dataset.range( - range_start, range_start + range_size).repeat(num_repeats).apply( - batching.map_and_batch( - map_func=_map_fn, - batch_size=batch_size, - num_parallel_calls=num_parallel_calls, - drop_remainder=drop_remainder)) - - self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), - num_outputs_keep_remainder) - self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), - num_outputs_drop_remainder) - - -class PaddedBatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def testPaddedBatch(self): - - def build_dataset(seq_lens): - return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( - lambda x: array_ops.fill([x], x)).padded_batch( - 4, padded_shapes=[-1]) - - seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) - seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) - self.run_core_tests(lambda: build_dataset(seq_lens1), - lambda: build_dataset(seq_lens2), 8) - - def testPaddedBatchNonDefaultPadding(self): - - def build_dataset(seq_lens): - - def fill_tuple(x): - filled = array_ops.fill([x], x) - return (filled, string_ops.as_string(filled)) - - padded_shape = [-1] - return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( - fill_tuple).padded_batch( - 4, - padded_shapes=(padded_shape, padded_shape), - padding_values=(-1, "")) - - seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) - seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) - self.run_core_tests(lambda: build_dataset(seq_lens1), - lambda: build_dataset(seq_lens2), 8) + with self.test_session() as sess: + for _ in range(10): + self.assertAllEqual([element for _ in range(10)], sess.run(get_next)) class RestructuredDatasetTest(test.TestCase): diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index 4fbfbfdbdd7ffd1019cef5bab7ffd5c149c37fcc..5fc7e51d814901985d33525b782434386c3ad18a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -21,7 +21,6 @@ import random import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import grouping from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op @@ -176,37 +175,27 @@ class GroupByReducerTest(test.TestCase): dataset.apply( grouping.group_by_reducer(lambda _: "wrong", reducer)) + def testTuple(self): + def init_fn(_): + return np.array([], dtype=np.int64), np.int64(0) -class GroupByReducerSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): + def reduce_fn(state, value): + s1, s2 = state + v1, v2 = value + return array_ops.concat([s1, [v1]], 0), s2 + v2 - def _build_dataset(self, components): - reducer = grouping.Reducer( - init_func=lambda _: np.int64(0), - reduce_func=lambda x, y: x + y, - finalize_func=lambda x: x) + def finalize_fn(s1, s2): + return s1, s2 - return dataset_ops.Dataset.from_tensor_slices(components).apply( - grouping.group_by_reducer(lambda x: x % 5, reducer)) - - def testCoreGroupByReducer(self): - components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64) - self.verify_unused_iterator( - lambda: self._build_dataset(components), 5, verify_exhausted=True) - self.verify_init_before_restore( - lambda: self._build_dataset(components), 5, verify_exhausted=True) - self.verify_multiple_breaks( - lambda: self._build_dataset(components), 5, verify_exhausted=True) - self.verify_reset_restored_iterator( - lambda: self._build_dataset(components), 5, verify_exhausted=True) - self.verify_restore_in_empty_graph( - lambda: self._build_dataset(components), 5, verify_exhausted=True) - diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64) - self.verify_restore_in_modified_graph( - lambda: self._build_dataset(components), - lambda: self._build_dataset(diff_components), - 5, - verify_exhausted=True) + reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn) + dataset = dataset_ops.Dataset.zip( + (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply( + grouping.group_by_reducer(lambda x, y: np.int64(0), reducer)) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + x, y = sess.run(get_next) + self.assertAllEqual(x, np.asarray([x for x in range(10)])) + self.assertEqual(y, 45) class GroupByWindowTest(test.TestCase): @@ -353,34 +342,6 @@ class GroupByWindowTest(test.TestCase): self.assertEqual(len(components), sum(counts)) -class GroupByWindowSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset(self, components): - return dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply( - grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4)) - - def testCoreGroupByWindow(self): - components = np.array( - [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64) - self.verify_unused_iterator( - lambda: self._build_dataset(components), 12, verify_exhausted=False) - self.verify_init_before_restore( - lambda: self._build_dataset(components), 12, verify_exhausted=False) - self.verify_multiple_breaks( - lambda: self._build_dataset(components), 12, verify_exhausted=False) - self.verify_reset_restored_iterator( - lambda: self._build_dataset(components), 12, verify_exhausted=False) - self.verify_restore_in_empty_graph( - lambda: self._build_dataset(components), 12, verify_exhausted=False) - diff_components = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64) - self.verify_restore_in_modified_graph( - lambda: self._build_dataset(components), - lambda: self._build_dataset(diff_components), - 12, - verify_exhausted=False) - - # NOTE(mrry): These tests are based on the tests in bucket_ops_test.py. # Currently, they use a constant batch size, though should be made to use a # different batch size per key. diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py index 97b5e9416521dcad9ee5047a8275f8fd0142e338..df115175f5046803ada036563be1ca802f7ad0cd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py @@ -33,7 +33,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_parsing_ops +from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import googletest from tensorflow.python.platform import test @@ -76,7 +76,7 @@ class CsvDatasetOpTest(test.TestCase): filenames = self.setup_files(inputs) dataset_expected = core_readers.TextLineDataset(filenames) dataset_expected = dataset_expected.map( - lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) + lambda l: parsing_ops.decode_csv(l, **kwargs)) dataset_actual = readers.CsvDataset(filenames, **kwargs) return (dataset_actual, dataset_expected) @@ -581,7 +581,7 @@ class CsvDatasetBenchmark(test.Benchmark): num_cols = self._num_cols[i] kwargs = {'record_defaults': [[0.0]] * num_cols} dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() - dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop + dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop self._runBenchmark(dataset, num_cols, 'csv_float_map_decode_csv') self._tearDown() @@ -591,7 +591,7 @@ class CsvDatasetBenchmark(test.Benchmark): num_cols = self._num_cols[i] kwargs = {'record_defaults': [['']] * num_cols} dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() - dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop + dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop self._runBenchmark(dataset, num_cols, 'csv_strings_map_decode_csv') self._tearDown() diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py index a842502cc6fe3605dde0be5f50cf46e3e37d7ed4..a2ab3de52e8e512e3cba399f7a1725e5570cfd01 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py @@ -17,14 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import batching from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes -from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -70,63 +66,5 @@ class DatasetConstructorTest(test.TestCase): # pylint: enable=protected-access -class DatasetConstructorSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_tensor_dataset(self, variable_array): - components = (variable_array, np.array([1, 2, 3]), np.array(37.0)) - - return dataset_ops.Dataset.from_tensors(components) - - def testFromTensorsCore(self): - # Equal length components - arr = np.array(1) - num_outputs = 1 - diff_arr = np.array(2) - self.run_core_tests(lambda: self._build_tensor_dataset(arr), - lambda: self._build_tensor_dataset(diff_arr), - num_outputs) - - def _build_tensor_slices_dataset(self, components): - return dataset_ops.Dataset.from_tensor_slices(components) - - def testFromTensorSlicesCore(self): - # Equal length components - components = (np.tile(np.array([[1], [2], [3], [4]]), 20), - np.tile(np.array([[12], [13], [14], [15]]), 22), - np.array([37.0, 38.0, 39.0, 40.0])) - - diff_comp = (np.tile(np.array([[1], [2], [3], [4]]), 20), - np.tile(np.array([[5], [6], [7], [8]]), 22), - np.array([1.0, 2.0, 3.0, 4.0])) - - dict_components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]} - - self.run_core_tests(lambda: self._build_tensor_slices_dataset(components), - lambda: self._build_tensor_slices_dataset(diff_comp), 4) - self.run_core_tests( - lambda: self._build_tensor_slices_dataset(dict_components), None, 3) - - def _build_sparse_tensor_slice_dataset(self, slices): - indices = np.array( - [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))], - dtype=np.int64) - values = np.array([val for s in slices for val in s], dtype=np.float64) - dense_shape = np.array( - [len(slices), max(len(s) for s in slices) + 1], dtype=np.int64) - sparse_components = sparse_tensor.SparseTensor(indices, values, dense_shape) - return dataset_ops.Dataset.from_sparse_tensor_slices(sparse_components) - - def testFromSparseTensorSlicesCore(self): - slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []] - diff_slices = [[1., 2.], [2.], [2., 3., 4.], [], [], []] - - self.run_core_tests( - lambda: self._build_sparse_tensor_slice_dataset(slices), - lambda: self._build_sparse_tensor_slice_dataset(diff_slices), - 9, - sparse_tensors=True) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py index 34b6a080c0aae7dfc228746139acc52cea4e6f28..9b1857de1a96c8f71788a1bf5085ef0605417fe7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py @@ -19,7 +19,6 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors @@ -34,8 +33,8 @@ class DirectedInterleaveDatasetTest(test.TestCase): input_datasets = [ dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10) ] - dataset = interleave_ops.DirectedInterleaveDataset(selector_dataset, - input_datasets) + dataset = interleave_ops._DirectedInterleaveDataset(selector_dataset, + input_datasets) iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() @@ -144,24 +143,5 @@ class DirectedInterleaveDatasetTest(test.TestCase): ], choice_dataset=dataset_ops.Dataset.from_tensors([1.0])) -class SampleFromDatasetsSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset(self, probs, num_samples): - dataset = interleave_ops.sample_from_datasets( - [ - dataset_ops.Dataset.from_tensors(i).repeat(None) - for i in range(len(probs)) - ], - probs, - seed=1813) - return dataset.take(num_samples) - - def testSerializationCore(self): - self.run_core_tests( - lambda: self._build_dataset([0.5, 0.5], 100), - lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index bee561e3e23a2ab6f314894caa21785347e6ca8b..44c3325a3db84bb844b7f860a7c925982f1e3d6a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -22,10 +22,8 @@ import math import threading import time -import numpy as np from six.moves import zip_longest -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes @@ -38,132 +36,6 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test -class InterleaveDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_iterator_graph(self, input_values, cycle_length, block_length): - repeat_count = 2 - return dataset_ops.Dataset.from_tensor_slices(input_values).repeat( - repeat_count).interleave( - lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), - cycle_length, block_length) - - def testSerializationCore(self): - input_values = np.array([4, 5, 6], dtype=np.int64) - num_outputs = np.sum(input_values) * 2 - # cycle_length > 1, block_length > 1 - cycle_length = 2 - block_length = 3 - # pylint: disable=g-long-lambda - self.run_core_tests( - lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), - lambda: self._build_iterator_graph( - input_values, cycle_length * 2, block_length * 1), - num_outputs) - # cycle_length = 1 - cycle_length = 1 - block_length = 3 - self.run_core_tests( - lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), - None, num_outputs) - # block_length = 1 - cycle_length = 2 - block_length = 1 - self.run_core_tests( - lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), - None, num_outputs) - # pylint: enable=g-long-lambda - - def testSparseCore(self): - - def _map_fn(i): - return sparse_tensor.SparseTensorValue( - indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) - - def _interleave_fn(x): - return dataset_ops.Dataset.from_tensor_slices( - sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) - - def _build_dataset(): - return dataset_ops.Dataset.range(10).map(_map_fn).interleave( - _interleave_fn, cycle_length=1) - - self.run_core_tests(_build_dataset, None, 20) - - -class ParallelInterleaveDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def setUp(self): - self.input_values = np.array([4, 5, 6], dtype=np.int64) - self.num_repeats = 2 - self.num_outputs = np.sum(self.input_values) * 2 - - def _build_ds(self, cycle_length, block_length, sloppy=False): - return (dataset_ops.Dataset.from_tensor_slices( - self.input_values).repeat(self.num_repeats).apply( - interleave_ops.parallel_interleave( - lambda x: dataset_ops.Dataset.range(10 * x, 11 * x), - cycle_length, block_length, sloppy))) - - def testSerializationCore(self): - # cycle_length > 1, block_length > 1 - cycle_length = 2 - block_length = 3 - self.run_core_tests( - lambda: self._build_ds(cycle_length, block_length), - lambda: self._build_ds(cycle_length * 2, block_length * 1), - self.num_outputs) - # cycle_length = 1 - cycle_length = 1 - block_length = 3 - self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), - None, self.num_outputs) - # block_length = 1 - cycle_length = 2 - block_length = 1 - self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), - None, self.num_outputs) - - def testSerializationWithSloppy(self): - break_points = self.gen_break_points(self.num_outputs, 10) - expected_outputs = np.repeat( - np.concatenate([np.arange(10 * x, 11 * x) for x in self.input_values]), - self.num_repeats).tolist() - - def run_test(cycle_length, block_length): - actual = self.gen_outputs( - lambda: self._build_ds(cycle_length, block_length, True), - break_points, self.num_outputs) - self.assertSequenceEqual(sorted(actual), expected_outputs) - - # cycle_length > 1, block_length > 1 - run_test(2, 3) - # cycle_length = 1 - run_test(1, 3) - # block_length = 1 - run_test(2, 1) - - def testSparseCore(self): - - def _map_fn(i): - return sparse_tensor.SparseTensorValue( - indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) - - def _interleave_fn(x): - return dataset_ops.Dataset.from_tensor_slices( - sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) - - def _build_dataset(): - return dataset_ops.Dataset.range(10).map(_map_fn).apply( - interleave_ops.parallel_interleave(_interleave_fn, 1)) - - self.run_core_tests(_build_dataset, None, 20) - - class ParallelInterleaveDatasetTest(test.TestCase): def setUp(self): diff --git a/tensorflow/contrib/data/python/ops/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py similarity index 100% rename from tensorflow/contrib/data/python/ops/iterator_ops_test.py rename to tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index 8d4042927970cab2f5a518fc0da49b38444dbcdf..270a2297b4d7b4fc44e3d1fa0aea8c9dfa5f39d3 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -21,20 +21,12 @@ import os import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import error_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import io_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -143,229 +135,5 @@ class MapDatasetTest(test.TestCase): sess.run(get_next) -class MapDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def setUp(self): - self._tensor_slice_len = 7 - self._num_epochs = 14 - self._num_outputs = self._tensor_slice_len * self._num_epochs - - def _build_ds(self, multiplier=37.0): - components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * - np.arange(self._tensor_slice_len)[:, np.newaxis], - np.array(multiplier) * np.arange(self._tensor_slice_len)) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - return ( - dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) - .repeat(self._num_epochs)) - - def testSaveRestoreCore(self): - self.run_core_tests( - self._build_ds, - lambda: self._build_ds(multiplier=15.0), - self._num_outputs) - - def testSaveStatefulFunction(self): - - def _build_ds(): - - def _map_fn(x): - return random_ops.random_uniform( - (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) - - return dataset_ops.Dataset.range(100).map(_map_fn) - - self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) - - def testCaptureVariableInMapFn(self): - - def _build_ds(): - counter_var = variable_scope.get_variable( - "counter", (), dtypes.int32, use_resource=True) - return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( - lambda _: counter_var.assign_add(1))) - - self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) - - def testCaptureConstantInMapFn(self): - - def _build_ds(): - constant_var = constant_op.constant(5) - return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( - lambda x: x + constant_var)) - - self.run_core_tests(_build_ds, None, 10) - - def testCaptureDefunInMapFn(self): - num_outputs = 100 - - def _build_ds(): - - @function.Defun(dtypes.int64) - def defun_fn(x): - return constant_op.constant(1000) + math_ops.to_int32(x) - - return dataset_ops.Dataset.range(num_outputs).map(defun_fn) - - self.run_core_tests(_build_ds, None, num_outputs) - - def testBuildDefunInMapFn(self): - num_outputs = 100 - - def _build_ds(): - - @function.Defun(dtypes.int64) - def defun_fn(x): - - @function.Defun(dtypes.int32) - def defun_fn_deep(x): - return constant_op.constant(1000) + math_ops.to_int32(x) - - return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) - - return dataset_ops.Dataset.range(num_outputs).map(defun_fn) - - self.run_core_tests(_build_ds, None, num_outputs) - - def testSparseCore(self): - - def _sparse(i): - return sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0]]), - values=(i * np.array([1])), - dense_shape=np.array([1, 1])) - - def _build_ds(num_outputs): - return dataset_ops.Dataset.range(num_outputs).map(_sparse) - - num_outputs = 10 - self.run_core_tests(lambda: _build_ds(num_outputs), - lambda: _build_ds(int(num_outputs / 2)), num_outputs) - - -class ParallelMapDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def setUp(self): - self._tensor_slice_len = 7 - self._num_epochs = 1 - self._num_outputs = self._tensor_slice_len * self._num_epochs - - def _build_ds(self, multiplier=37.0): - components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * - np.arange(self._tensor_slice_len)[:, np.newaxis], - np.array(multiplier) * np.arange(self._tensor_slice_len)) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - return (dataset_ops.Dataset.from_tensor_slices(components).map( - _map_fn, num_parallel_calls=3).repeat(self._num_epochs)) - - def _build_ds_with_prefetch(self, multiplier=37.0): - components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * - np.arange(self._tensor_slice_len)[:, np.newaxis], - np.array(multiplier) * np.arange(self._tensor_slice_len)) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - return (dataset_ops.Dataset.from_tensor_slices(components).map( - _map_fn, num_parallel_calls=3).repeat(self._num_epochs).prefetch(5)) - - def testSaveRestoreCore(self): - for ds_fn in [self._build_ds, self._build_ds_with_prefetch]: - self.run_core_tests( - ds_fn, - lambda: ds_fn(multiplier=15.0), - self._num_outputs) - - def testSaveStatefulFunction(self): - - def _build_ds(): - - def _map_fn(x): - return random_ops.random_uniform( - (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) - - return dataset_ops.Dataset.range(100).map( - _map_fn, num_parallel_calls=2).prefetch(2) - - self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) - - def testCaptureVariableInMapFn(self): - - def _build_ds(): - counter_var = variable_scope.get_variable( - "counter", (), dtypes.int32, use_resource=True) - return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( - lambda _: counter_var.assign_add(1), - num_parallel_calls=2).prefetch(2)) - - self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) - - def testCaptureConstantInMapFn(self): - - def _build_ds(): - constant_var = constant_op.constant(5) - return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( - lambda x: x + constant_var, num_parallel_calls=2).prefetch(2)) - - self.run_core_tests(_build_ds, None, 10) - - def testCaptureDefunInMapFn(self): - num_outputs = 100 - - def _build_ds(): - - @function.Defun(dtypes.int64) - def defun_fn(x): - return constant_op.constant(1000) + math_ops.to_int32(x) - - return dataset_ops.Dataset.range(num_outputs).map( - defun_fn, num_parallel_calls=2).prefetch(2) - - self.run_core_tests(_build_ds, None, num_outputs) - - def testBuildDefunInMapFn(self): - num_outputs = 100 - - def _build_ds(): - - @function.Defun(dtypes.int64) - def defun_fn(x): - - @function.Defun(dtypes.int32) - def defun_fn_deep(x): - return constant_op.constant(1000) + math_ops.to_int32(x) - - return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) - - return dataset_ops.Dataset.range(num_outputs).map( - defun_fn, num_parallel_calls=2).prefetch(2) - - self.run_core_tests(_build_ds, None, num_outputs) - - -class IgnoreErrorsSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_ds(self, components): - return dataset_ops.Dataset.from_tensor_slices(components).map( - lambda x: array_ops.check_numerics(x, "message")).apply( - error_ops.ignore_errors()) - - def testIgnoreErrorsCore(self): - components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) - diff_components = np.array([1., 2., 3., np.nan]).astype(np.float32) - num_outputs = 4 - self.run_core_tests(lambda: self._build_ds(components), - lambda: self._build_ds(diff_components), num_outputs) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py index 30f1847dcddbfaf379ef2b09185f7a8db4aaeae2..e35be8a23f3706bd170c09b967b4f419fc9a626e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import optimization from tensorflow.core.framework import graph_pb2 from tensorflow.python.data.ops import dataset_ops @@ -73,17 +72,5 @@ class OptimizeDatasetTest(test.TestCase): sess.run(get_next) -class OptimizeDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def testCore(self): - - def build_dataset(num_elements, batch_size): - return dataset_ops.Dataset.range(num_elements).map(lambda x: x * x).batch( - batch_size).apply(optimization.optimize(["map_and_batch_fusion"])) - - self.run_core_tests(lambda: build_dataset(200, 10), None, 20) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py index b08132cd72254326d965907a1fdafb8a820926a1..20ed6397505dbd77dbfe686147391c18b62c8718 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py @@ -68,6 +68,7 @@ class PrefetchingKernelsOpsTest(test.TestCase): with ops.device(device1): buffer_resource_handle = prefetching_ops.function_buffering_resource( f=_remote_fn, + output_types=[dtypes.float32], target_device=target, string_arg=ds_iterator_handle, buffer_size=3, @@ -201,6 +202,49 @@ class PrefetchingKernelsOpsTest(test.TestCase): sess.run(destroy_op) + def testStringsGPU(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + device0 = "/job:localhost/replica:0/task:0/cpu:0" + device1 = "/job:localhost/replica:0/task:0/gpu:0" + + ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"]) + ds_iterator = ds.make_one_shot_iterator() + ds_iterator_handle = ds_iterator.string_handle() + + @function.Defun(dtypes.string) + def _remote_fn(h): + remote_iterator = iterator_ops.Iterator.from_string_handle( + h, ds.output_types, ds.output_shapes) + return remote_iterator.get_next() + + target = constant_op.constant(device0) + with ops.device(device1): + buffer_resource_handle = prefetching_ops.function_buffering_resource( + f=_remote_fn, + output_types=[dtypes.string], + target_device=target, + string_arg=ds_iterator_handle, + buffer_size=3, + shared_name="strings") + + with ops.device(device1): + prefetch_op = prefetching_ops.function_buffering_resource_get_next( + function_buffer_resource=buffer_resource_handle, + output_types=[dtypes.string]) + destroy_op = resource_variable_ops.destroy_resource_op( + buffer_resource_handle, ignore_lookup_error=True) + + with self.test_session() as sess: + self.assertEqual(["a"], sess.run(prefetch_op)) + self.assertEqual(["b"], sess.run(prefetch_op)) + self.assertEqual(["c"], sess.run(prefetch_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(prefetch_op) + + sess.run(destroy_op) + class PrefetchToDeviceTest(test.TestCase): @@ -235,6 +279,36 @@ class PrefetchToDeviceTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testPrefetchToSameDevice(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device( + "/job:localhost/replica:0/task:0/device:CPU:0")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. + with ops.device("/cpu:0"): + iterator = device_dataset.make_one_shot_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + with self.test_session() as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + def testPrefetchDictToDevice(self): host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) device_dataset = host_dataset.apply( diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index 80e1cb0041024b68bd5268b5de5d69c88c839896..592642da0cfd84e50cb20d9b2e534411faf927e8 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -17,21 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import counter from tensorflow.contrib.data.python.ops import enumerate_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import io_ops -from tensorflow.python.ops import parsing_ops -from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -81,88 +73,5 @@ class RangeDatasetTest(test.TestCase): self.assertEqual(-2, sess.run(negative_get_next)) -class RangeDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _iterator_checkpoint_prefix_local(self): - return os.path.join(self.get_temp_dir(), "iterator") - - def _save_op(self, iterator_resource): - iterator_state_variant = gen_dataset_ops.serialize_iterator( - iterator_resource) - save_op = io_ops.write_file( - self._iterator_checkpoint_prefix_local(), - parsing_ops.serialize_tensor(iterator_state_variant)) - return save_op - - def _restore_op(self, iterator_resource): - iterator_state_variant = parsing_ops.parse_tensor( - io_ops.read_file(self._iterator_checkpoint_prefix_local()), - dtypes.variant) - restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, - iterator_state_variant) - return restore_op - - def testSaveRestore(self): - - def _build_graph(start, stop): - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - # Saving and restoring in different sessions. - start = 2 - stop = 10 - break_point = 5 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, _ = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - with ops.Graph().as_default() as g: - init_op, get_next, _, restore_op = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(init_op) - sess.run(restore_op) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Saving and restoring in same session. - with ops.Graph().as_default() as g: - init_op, get_next, save_op, restore_op = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - sess.run(restore_op) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def _build_range_dataset(self, start, stop): - return dataset_ops.Dataset.range(start, stop) - - def testRangeCore(self): - start = 2 - stop = 10 - stop_1 = 8 - self.run_core_tests(lambda: self._build_range_dataset(start, stop), - lambda: self._build_range_dataset(start, stop_1), - stop - start) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 3b07ef290bc38daa37472ef8919f3350851fe370..9df403ef50e459d94b8edf3f651c7c95baf3ec42 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -17,266 +17,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import gzip import os -import zlib import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base from tensorflow.contrib.data.python.ops import readers -from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.lib.io import python_io -from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import string_ops from tensorflow.python.platform import test -from tensorflow.python.util import compat - - -class TextLineDatasetTestBase(test.TestCase): - - def _lineText(self, f, l): - return compat.as_bytes("%d: %d" % (f, l)) - - def _createFiles(self, - num_files, - num_lines, - crlf=False, - compression_type=None): - filenames = [] - for i in range(num_files): - fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i) - filenames.append(fn) - contents = [] - for j in range(num_lines): - contents.append(self._lineText(i, j)) - # Always include a newline after the record unless it is - # at the end of the file, in which case we include it - if j + 1 != num_lines or i == 0: - contents.append(b"\r\n" if crlf else b"\n") - contents = b"".join(contents) - - if not compression_type: - with open(fn, "wb") as f: - f.write(contents) - elif compression_type == "GZIP": - with gzip.GzipFile(fn, "wb") as f: - f.write(contents) - elif compression_type == "ZLIB": - contents = zlib.compress(contents) - with open(fn, "wb") as f: - f.write(contents) - else: - raise ValueError("Unsupported compression_type", compression_type) - - return filenames - - -class TextLineDatasetSerializationTest( - TextLineDatasetTestBase, - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_iterator_graph(self, test_filenames, compression_type=None): - return core_readers.TextLineDataset( - test_filenames, compression_type=compression_type, buffer_size=10) - - def testTextLineCore(self): - compression_types = [None, "GZIP", "ZLIB"] - num_files = 5 - lines_per_file = 5 - num_outputs = num_files * lines_per_file - for compression_type in compression_types: - test_filenames = self._createFiles( - num_files, - lines_per_file, - crlf=True, - compression_type=compression_type) - # pylint: disable=cell-var-from-loop - self.run_core_tests( - lambda: self._build_iterator_graph(test_filenames, compression_type), - lambda: self._build_iterator_graph(test_filenames), num_outputs) - # pylint: enable=cell-var-from-loop - - -class FixedLengthRecordReaderTestBase(test.TestCase): - - def setUp(self): - super(FixedLengthRecordReaderTestBase, self).setUp() - self._num_files = 2 - self._num_records = 7 - self._header_bytes = 5 - self._record_bytes = 3 - self._footer_bytes = 2 - - def _record(self, f, r): - return compat.as_bytes(str(f * 2 + r) * self._record_bytes) - - def _createFiles(self): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i) - filenames.append(fn) - with open(fn, "wb") as f: - f.write(b"H" * self._header_bytes) - for j in range(self._num_records): - f.write(self._record(i, j)) - f.write(b"F" * self._footer_bytes) - return filenames - - -class FixedLengthRecordDatasetSerializationTest( - FixedLengthRecordReaderTestBase, - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_iterator_graph(self, num_epochs, compression_type=None): - filenames = self._createFiles() - return core_readers.FixedLengthRecordDataset( - filenames, self._record_bytes, self._header_bytes, - self._footer_bytes).repeat(num_epochs) - - def testFixedLengthRecordCore(self): - num_epochs = 5 - num_outputs = num_epochs * self._num_files * self._num_records - self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), - lambda: self._build_iterator_graph(num_epochs * 2), - num_outputs) - - -class TFRecordDatasetTestBase(test.TestCase): - - def setUp(self): - super(TFRecordDatasetTestBase, self).setUp() - self._num_files = 2 - self._num_records = 7 - - self.test_filenames = self._createFiles() - - self.filenames = array_ops.placeholder(dtypes.string, shape=[None]) - self.num_epochs = array_ops.placeholder_with_default( - constant_op.constant(1, dtypes.int64), shape=[]) - self.compression_type = array_ops.placeholder_with_default("", shape=[]) - self.batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - - repeat_dataset = core_readers.TFRecordDataset( - self.filenames, self.compression_type).repeat(self.num_epochs) - batch_dataset = repeat_dataset.batch(self.batch_size) - - iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) - self.init_op = iterator.make_initializer(repeat_dataset) - self.init_batch_op = iterator.make_initializer(batch_dataset) - self.get_next = iterator.get_next() - - def _record(self, f, r): - return compat.as_bytes("Record %d of file %d" % (r, f)) - - def _createFiles(self): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) - filenames.append(fn) - writer = python_io.TFRecordWriter(fn) - for j in range(self._num_records): - writer.write(self._record(i, j)) - writer.close() - return filenames - - -class TFRecordDatasetSerializationTest( - TFRecordDatasetTestBase, - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_iterator_graph(self, - num_epochs, - batch_size=1, - compression_type=None, - buffer_size=None): - filenames = self._createFiles() - if compression_type is "ZLIB": - zlib_files = [] - for i, fn in enumerate(filenames): - with open(fn, "rb") as f: - cdata = zlib.compress(f.read()) - zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i) - with open(zfn, "wb") as f: - f.write(cdata) - zlib_files.append(zfn) - filenames = zlib_files - - elif compression_type is "GZIP": - gzip_files = [] - for i, fn in enumerate(self.test_filenames): - with open(fn, "rb") as f: - gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i) - with gzip.GzipFile(gzfn, "wb") as gzf: - gzf.write(f.read()) - gzip_files.append(gzfn) - filenames = gzip_files - - return core_readers.TFRecordDataset( - filenames, compression_type, - buffer_size=buffer_size).repeat(num_epochs).batch(batch_size) - - def testTFRecordWithoutBufferCore(self): - num_epochs = 5 - batch_size = num_epochs - num_outputs = num_epochs * self._num_files * self._num_records // batch_size - # pylint: disable=g-long-lambda - self.run_core_tests( - lambda: self._build_iterator_graph(num_epochs, batch_size, - buffer_size=0), - lambda: self._build_iterator_graph(num_epochs * 2, batch_size), - num_outputs) - self.run_core_tests( - lambda: self._build_iterator_graph(num_epochs, buffer_size=0), None, - num_outputs * batch_size) - # pylint: enable=g-long-lambda - - def testTFRecordWithBufferCore(self): - num_epochs = 5 - num_outputs = num_epochs * self._num_files * self._num_records - self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), - lambda: self._build_iterator_graph(num_epochs * 2), - num_outputs) - - def testTFRecordWithCompressionCore(self): - num_epochs = 5 - num_outputs = num_epochs * self._num_files * self._num_records - self.run_core_tests( - lambda: self._build_iterator_graph(num_epochs, compression_type="ZLIB"), - lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) - self.run_core_tests( - lambda: self._build_iterator_graph(num_epochs, compression_type="GZIP"), - lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) - - -def _interleave(iterators, cycle_length): - pending_iterators = iterators - open_iterators = [] - num_open = 0 - for i in range(cycle_length): - if pending_iterators: - open_iterators.append(pending_iterators.pop(0)) - num_open += 1 - - while num_open: - for i in range(min(cycle_length, len(open_iterators))): - if open_iterators[i] is None: - continue - try: - yield next(open_iterators[i]) - except StopIteration: - if pending_iterators: - open_iterators[i] = pending_iterators.pop(0) - else: - open_iterators[i] = None - num_open -= 1 class ReadBatchFeaturesTest( @@ -914,7 +668,30 @@ class MakeCsvDatasetTest(test.TestCase): self.assertFalse(all_equal) -class MakeTFRecordDatasetTest(TFRecordDatasetTestBase): +class MakeTFRecordDatasetTest( + reader_dataset_ops_test_base.TFRecordDatasetTestBase): + + def _interleave(self, iterators, cycle_length): + pending_iterators = iterators + open_iterators = [] + num_open = 0 + for i in range(cycle_length): + if pending_iterators: + open_iterators.append(pending_iterators.pop(0)) + num_open += 1 + + while num_open: + for i in range(min(cycle_length, len(open_iterators))): + if open_iterators[i] is None: + continue + try: + yield next(open_iterators[i]) + except StopIteration: + if pending_iterators: + open_iterators[i] = pending_iterators.pop(0) + else: + open_iterators[i] = None + num_open -= 1 def _next_expected_batch(self, file_indices, @@ -930,8 +707,8 @@ class MakeTFRecordDatasetTest(TFRecordDatasetTestBase): yield j, i def _next_record_interleaved(file_indices, cycle_length): - return _interleave([_next_record([i]) for i in file_indices], - cycle_length) + return self._interleave([_next_record([i]) for i in file_indices], + cycle_length) record_batch = [] batch_index = 0 diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py index 805a7c7b7384d53cc166a48ba243502ef8643280..e63bc4c72049c61aa40314ffebe5c4366a818d46 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py @@ -12,24 +12,57 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Base class for testing reader datasets.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gzip import os +import zlib from tensorflow.contrib.data.python.ops import readers from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.lib.io import python_io +from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test from tensorflow.python.util import compat +class FixedLengthRecordDatasetTestBase(test.TestCase): + """Base class for setting up and testing FixedLengthRecordDataset.""" + + def setUp(self): + super(FixedLengthRecordDatasetTestBase, self).setUp() + self._num_files = 2 + self._num_records = 7 + self._header_bytes = 5 + self._record_bytes = 3 + self._footer_bytes = 2 + + def _record(self, f, r): + return compat.as_bytes(str(f * 2 + r) * self._record_bytes) + + def _createFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i) + filenames.append(fn) + with open(fn, "wb") as f: + f.write(b"H" * self._header_bytes) + for j in range(self._num_records): + f.write(self._record(i, j)) + f.write(b"F" * self._footer_bytes) + return filenames + + class ReadBatchFeaturesTestBase(test.TestCase): """Base class for setting up and testing `make_batched_feature_dataset`.""" @@ -216,3 +249,83 @@ class ReadBatchFeaturesTestBase(test.TestCase): actual_batch = self._next_actual_batch(sess) for i in range(len(expected_batch)): self.assertAllEqual(expected_batch[i], actual_batch[i]) + + +class TextLineDatasetTestBase(test.TestCase): + """Base class for setting up and testing TextLineDataset.""" + + def _lineText(self, f, l): + return compat.as_bytes("%d: %d" % (f, l)) + + def _createFiles(self, + num_files, + num_lines, + crlf=False, + compression_type=None): + filenames = [] + for i in range(num_files): + fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i) + filenames.append(fn) + contents = [] + for j in range(num_lines): + contents.append(self._lineText(i, j)) + # Always include a newline after the record unless it is + # at the end of the file, in which case we include it + if j + 1 != num_lines or i == 0: + contents.append(b"\r\n" if crlf else b"\n") + contents = b"".join(contents) + + if not compression_type: + with open(fn, "wb") as f: + f.write(contents) + elif compression_type == "GZIP": + with gzip.GzipFile(fn, "wb") as f: + f.write(contents) + elif compression_type == "ZLIB": + contents = zlib.compress(contents) + with open(fn, "wb") as f: + f.write(contents) + else: + raise ValueError("Unsupported compression_type", compression_type) + + return filenames + + +class TFRecordDatasetTestBase(test.TestCase): + """Base class for setting up and testing TFRecordDataset.""" + + def setUp(self): + super(TFRecordDatasetTestBase, self).setUp() + self._num_files = 2 + self._num_records = 7 + + self.test_filenames = self._createFiles() + + self.filenames = array_ops.placeholder(dtypes.string, shape=[None]) + self.num_epochs = array_ops.placeholder_with_default( + constant_op.constant(1, dtypes.int64), shape=[]) + self.compression_type = array_ops.placeholder_with_default("", shape=[]) + self.batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = core_readers.TFRecordDataset( + self.filenames, self.compression_type).repeat(self.num_epochs) + batch_dataset = repeat_dataset.batch(self.batch_size) + + iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) + self.init_op = iterator.make_initializer(repeat_dataset) + self.init_batch_op = iterator.make_initializer(batch_dataset) + self.get_next = iterator.get_next() + + def _record(self, f, r): + return compat.as_bytes("Record %d of file %d" % (r, f)) + + def _createFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) + filenames.append(fn) + writer = python_io.TFRecordWriter(fn) + for j in range(self._num_records): + writer.write(self._record(i, j)) + writer.close() + return filenames diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py index 520da7d6ff3ed50352a89c8a2d4f08122eb922dd..c5cfddb72b56a1bcffc80c0dd34994def3ee45cd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py @@ -17,10 +17,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import time + from absl.testing import parameterized import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -import time from tensorflow.contrib.data.python.ops import resampling from tensorflow.python.data.ops import dataset_ops diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py index eb2ceff893543f710d4f0246adf4e6367a2deeb0..42cada0b97bcd9ab755297e8b1f0667766f7999e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py @@ -21,7 +21,6 @@ import itertools import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import scan_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context @@ -64,7 +63,7 @@ class ScanDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFibonacci(self): iterator = dataset_ops.Dataset.from_tensors(1).repeat(None).apply( scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])) @@ -168,18 +167,5 @@ class ScanDatasetTest(test.TestCase): scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn)) -class ScanDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset(self, num_elements): - return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply( - scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1]))) - - def testScanCore(self): - num_output = 5 - self.run_core_tests(lambda: self._build_dataset(num_output), - lambda: self._build_dataset(2), num_output) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..686788522acdf1c5e91132c38bdc81d10d2a0cc2 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD @@ -0,0 +1,526 @@ +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "dataset_serialization_test_base", + srcs = [ + "dataset_serialization_test_base.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:lookup_ops", + "//tensorflow/python:platform", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:iterator_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "batch_dataset_serialization_test", + size = "medium", + srcs = ["batch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "cache_dataset_serialization_test", + size = "small", + srcs = ["cache_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "concatenate_dataset_serialization_test", + size = "small", + srcs = ["concatenate_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "dataset_constructor_serialization_test", + size = "medium", + srcs = ["dataset_constructor_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "filter_dataset_serialization_test", + size = "medium", + srcs = ["filter_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "fixed_length_record_dataset_serialization_test", + size = "medium", + srcs = ["fixed_length_record_dataset_serialization_test.py"], + shard_count = 4, + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:readers", + ], +) + +py_test( + name = "flat_map_dataset_serialization_test", + size = "medium", + srcs = ["flat_map_dataset_serialization_test.py"], + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "group_by_reducer_serialization_test", + size = "medium", + srcs = ["group_by_reducer_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:grouping", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "group_by_window_serialization_test", + size = "medium", + srcs = ["group_by_window_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:grouping", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "ignore_errors_serialization_test", + size = "small", + srcs = ["ignore_errors_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:error_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "interleave_dataset_serialization_test", + size = "medium", + srcs = ["interleave_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "map_and_batch_dataset_serialization_test", + size = "medium", + srcs = ["map_and_batch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/python:client_testlib", + "//tensorflow/python:math_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "map_dataset_serialization_test", + size = "medium", + srcs = ["map_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "optimize_dataset_serialization_test", + size = "small", + srcs = ["optimize_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "padded_batch_dataset_serialization_test", + size = "medium", + srcs = ["padded_batch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:string_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "parallel_interleave_dataset_serialization_test", + size = "medium", + srcs = ["parallel_interleave_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:interleave_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "parallel_map_dataset_serialization_test", + size = "medium", + srcs = ["parallel_map_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "prefetch_dataset_serialization_test", + size = "small", + srcs = ["prefetch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "range_dataset_serialization_test", + size = "small", + srcs = ["range_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "sample_from_datasets_serialization_test", + size = "medium", + srcs = ["sample_from_datasets_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:interleave_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "scan_dataset_serialization_test", + size = "small", + srcs = ["scan_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:scan_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "sequence_dataset_serialization_test", + size = "medium", + srcs = ["sequence_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "serialization_integration_test", + size = "small", + srcs = ["serialization_integration_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "shuffle_and_repeat_dataset_serialization_test", + size = "medium", + srcs = ["shuffle_and_repeat_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:shuffle_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "shuffle_dataset_serialization_test", + size = "medium", + srcs = ["shuffle_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "sql_dataset_serialization_test", + size = "small", + srcs = ["sql_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/kernel_tests:sql_dataset_op_test_base", + "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + ], +) + +py_test( + name = "stats_dataset_serialization_test", + size = "medium", + srcs = ["stats_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:stats_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "textline_dataset_serialization_test", + size = "medium", + srcs = ["textline_dataset_serialization_test.py"], + shard_count = 4, + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:readers", + ], +) + +py_test( + name = "tf_record_dataset_serialization_test", + size = "medium", + srcs = ["tf_record_dataset_serialization_test.py"], + shard_count = 4, + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:readers", + ], +) + +py_test( + name = "unbatch_dataset_serialization_test", + size = "medium", + srcs = ["unbatch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "unique_dataset_serialization_test", + size = "small", + srcs = ["unique_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:unique", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "zip_dataset_serialization_test", + size = "small", + srcs = ["zip_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..af87d8b6083de268fafd4346d2871f14e0f4e7c9 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py @@ -0,0 +1,83 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the BatchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import batching +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class BatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): + components = ( + np.arange(tensor_slice_len), + np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(tensor_slice_len)) + + return dataset_ops.Dataset.from_tensor_slices(components).batch(batch_size) + + def testCore(self): + tensor_slice_len = 8 + batch_size = 2 + num_outputs = tensor_slice_len // batch_size + self.run_core_tests( + lambda: self.build_dataset(15.0, tensor_slice_len, batch_size), + lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), + num_outputs) + + def _build_dataset_dense_to_sparse(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).map( + lambda x: array_ops.fill([x], x)).apply( + batching.dense_to_sparse_batch(4, [12])) + + def testDenseToSparseBatchDatasetCore(self): + components = np.random.randint(5, size=(40,)).astype(np.int32) + diff_comp = np.random.randint(2, size=(100,)).astype(np.int32) + + num_outputs = len(components) // 4 + self.run_core_tests(lambda: self._build_dataset_dense_to_sparse(components), + lambda: self._build_dataset_dense_to_sparse(diff_comp), + num_outputs) + + def _sparse(self, i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + def _build_dataset_sparse(self, batch_size=5): + return dataset_ops.Dataset.range(10).map(self._sparse).batch(batch_size) + + def testSparseCore(self): + self.run_core_tests(self._build_dataset_sparse, + lambda: self._build_dataset_sparse(2), 2) + + def _build_dataset_nested_sparse(self): + return dataset_ops.Dataset.range(10).map(self._sparse).batch(5).batch(2) + + def testNestedSparseCore(self): + self.run_core_tests(self._build_dataset_nested_sparse, None, 1) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py similarity index 97% rename from tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py index f08216a303e2d7dee155ccadcdb9f42f1b24ea0f..a0a1100893c7384b0e2bd9fcfdaa8d3698b95d28 100644 --- a/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py @@ -12,20 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental features of CacheDataset.""" +"""Tests for the CacheDataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import os -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.platform import test -class CacheToFileDatasetSerializationTest( +class CacheDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): def setUp(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py similarity index 92% rename from tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py index 17f2980157ddd0350dafd1d745cbb9b64e65f7c5..96f13d75a31b6762b35062e6cf8c0cdb4d61d2c5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the ConcatenateDataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2139b5c33db69a7ffbdebee74e5824928004b407 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py @@ -0,0 +1,95 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the dataset constructors serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.platform import test + + +class FromTensorsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_tensor_dataset(self, variable_array): + components = (variable_array, np.array([1, 2, 3]), np.array(37.0)) + + return dataset_ops.Dataset.from_tensors(components) + + def testFromTensorsCore(self): + # Equal length components + arr = np.array(1) + num_outputs = 1 + diff_arr = np.array(2) + self.run_core_tests(lambda: self._build_tensor_dataset(arr), + lambda: self._build_tensor_dataset(diff_arr), + num_outputs) + + +class FromTensorSlicesSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_tensor_slices_dataset(self, components): + return dataset_ops.Dataset.from_tensor_slices(components) + + def testFromTensorSlicesCore(self): + # Equal length components + components = (np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[12], [13], [14], [15]]), 22), + np.array([37.0, 38.0, 39.0, 40.0])) + + diff_comp = (np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[5], [6], [7], [8]]), 22), + np.array([1.0, 2.0, 3.0, 4.0])) + + dict_components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]} + + self.run_core_tests(lambda: self._build_tensor_slices_dataset(components), + lambda: self._build_tensor_slices_dataset(diff_comp), 4) + self.run_core_tests( + lambda: self._build_tensor_slices_dataset(dict_components), None, 3) + + +class FromSparseTensorSlicesSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_sparse_tensor_slice_dataset(self, slices): + indices = np.array( + [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))], + dtype=np.int64) + values = np.array([val for s in slices for val in s], dtype=np.float64) + dense_shape = np.array( + [len(slices), max(len(s) for s in slices) + 1], dtype=np.int64) + sparse_components = sparse_tensor.SparseTensor(indices, values, dense_shape) + return dataset_ops.Dataset.from_sparse_tensor_slices(sparse_components) + + def testFromSparseTensorSlicesCore(self): + slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []] + diff_slices = [[1., 2.], [2.], [2., 3., 4.], [], [], []] + + self.run_core_tests( + lambda: self._build_sparse_tensor_slice_dataset(slices), + lambda: self._build_sparse_tensor_slice_dataset(diff_slices), + 9, + sparse_tensors=True) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py similarity index 100% rename from tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py similarity index 91% rename from tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py index b572d6ed770fc0fe0f852359baf343c55966eddd..7c170078a11aadce9e5730437e4c25209bd58edb 100644 --- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the FilterDataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import math_ops @@ -35,7 +35,7 @@ class FilterDatasetSerializationTest( def testFilterCore(self): div = 3 - num_outputs = np.sum([x % 3 is not 2 for x in range(100)]) + num_outputs = np.sum([x % 3 != 2 for x in range(100)]) self.run_core_tests(lambda: self._build_filter_range_graph(div), lambda: self._build_filter_range_graph(div * 2), num_outputs) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..34392d88d4505175c4562e23d5f0c4116e00b022 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py @@ -0,0 +1,45 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the FixedLengthRecordDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.platform import test + + +class FixedLengthRecordDatasetSerializationTest( + reader_dataset_ops_test_base.FixedLengthRecordDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, num_epochs, compression_type=None): + filenames = self._createFiles() + return core_readers.FixedLengthRecordDataset( + filenames, self._record_bytes, self._header_bytes, + self._footer_bytes).repeat(num_epochs) + + def testFixedLengthRecordCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), + lambda: self._build_iterator_graph(num_epochs * 2), + num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py similarity index 96% rename from tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py index f3feecef32e587045be25056815315136a883ca7..16051ffd3fd1e1e7ff419f28109df7bc1f165257 100644 --- a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the FlatMapDataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..571e0899bbc1f856d66f85c4f6f3ac78aa0b1368 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py @@ -0,0 +1,61 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the GroupByReducer serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class GroupByReducerSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, components): + reducer = grouping.Reducer( + init_func=lambda _: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + + return dataset_ops.Dataset.from_tensor_slices(components).apply( + grouping.group_by_reducer(lambda x: x % 5, reducer)) + + def testCoreGroupByReducer(self): + components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64) + self.verify_unused_iterator( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_init_before_restore( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_multiple_breaks( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_reset_restored_iterator( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_restore_in_empty_graph( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64) + self.verify_restore_in_modified_graph( + lambda: self._build_dataset(components), + lambda: self._build_dataset(diff_components), + 5, + verify_exhausted=True) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f86af4084ef61c2f20dbe2fb388a20287676f39d --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py @@ -0,0 +1,57 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the GroupByWindow serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class GroupByWindowSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply( + grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4)) + + def testCoreGroupByWindow(self): + components = np.array( + [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64) + self.verify_unused_iterator( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_init_before_restore( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_multiple_breaks( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_reset_restored_iterator( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_restore_in_empty_graph( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + diff_components = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64) + self.verify_restore_in_modified_graph( + lambda: self._build_dataset(components), + lambda: self._build_dataset(diff_components), + 12, + verify_exhausted=False) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..65ae9923b8f64dddcd54afc53e2fa67bc770fc2a --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py @@ -0,0 +1,46 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the IgnoreErrors input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import error_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class IgnoreErrorsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).map( + lambda x: array_ops.check_numerics(x, "message")).apply( + error_ops.ignore_errors()) + + def testIgnoreErrorsCore(self): + components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) + diff_components = np.array([1., 2., 3., np.nan]).astype(np.float32) + num_outputs = 4 + self.run_core_tests(lambda: self._build_ds(components), + lambda: self._build_ds(diff_components), num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ac3892fe81a1c0d325ddc5f501c2caed4b53f5d5 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py @@ -0,0 +1,86 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the InterleaveDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +class InterleaveDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, input_values, cycle_length, block_length): + repeat_count = 2 + return dataset_ops.Dataset.from_tensor_slices(input_values).repeat( + repeat_count).interleave( + lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), + cycle_length, block_length) + + def testSerializationCore(self): + input_values = np.array([4, 5, 6], dtype=np.int64) + num_outputs = np.sum(input_values) * 2 + # cycle_length > 1, block_length > 1 + cycle_length = 2 + block_length = 3 + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: self._build_iterator_graph( + input_values, cycle_length, block_length), + lambda: self._build_iterator_graph( + input_values, cycle_length * 2, block_length * 1), + num_outputs) + # cycle_length = 1 + cycle_length = 1 + block_length = 3 + self.run_core_tests( + lambda: self._build_iterator_graph( + input_values, cycle_length, block_length), + None, num_outputs) + # block_length = 1 + cycle_length = 2 + block_length = 1 + self.run_core_tests( + lambda: self._build_iterator_graph( + input_values, cycle_length, block_length), + None, num_outputs) + # pylint: enable=g-long-lambda + + def testSparseCore(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + def _build_dataset(): + return dataset_ops.Dataset.range(10).map(_map_fn).interleave( + _interleave_fn, cycle_length=1) + + self.run_core_tests(_build_dataset, None, 20) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c9cd211328fa595c0ce0efe3509e8ba9dc06af80 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py @@ -0,0 +1,88 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the MapAndBatchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import batching +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class MapAndBatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testNumParallelBatches(self): + range_size = 11 + num_repeats = 2 + batch_size = 5 + total_outputs = range_size * num_repeats + num_outputs_drop_remainder = total_outputs // batch_size + num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) + num_parallel_batches = 2 + + def build_ds(range_start, drop_remainder=False): + + def _map_fn(x): + return math_ops.square(x) + + return dataset_ops.Dataset.range( + range_start, range_start + range_size).repeat(num_repeats).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_batches=num_parallel_batches, + drop_remainder=drop_remainder)) + + self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), + num_outputs_keep_remainder) + self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), + num_outputs_drop_remainder) + + def testNumParallelCalls(self): + range_size = 11 + num_repeats = 2 + batch_size = 5 + total_outputs = range_size * num_repeats + num_outputs_drop_remainder = total_outputs // batch_size + num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) + num_parallel_calls = 7 + + def build_ds(range_start, drop_remainder=False): + + def _map_fn(x): + return math_ops.square(x) + + return dataset_ops.Dataset.range( + range_start, range_start + range_size).repeat(num_repeats).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_calls=num_parallel_calls, + drop_remainder=drop_remainder)) + + self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), + num_outputs_keep_remainder) + self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), + num_outputs_drop_remainder) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ab783e5cce95ed63fe64c273abb3846121c7a274 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py @@ -0,0 +1,140 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the MapDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test + + +class MapDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self._tensor_slice_len = 7 + self._num_epochs = 14 + self._num_outputs = self._tensor_slice_len * self._num_epochs + + def _build_ds(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + return ( + dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(self._num_epochs)) + + def testSaveRestoreCore(self): + self.run_core_tests( + self._build_ds, + lambda: self._build_ds(multiplier=15.0), + self._num_outputs) + + def testSaveStatefulFunction(self): + + def _build_ds(): + + def _map_fn(x): + return random_ops.random_uniform( + (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) + + return dataset_ops.Dataset.range(100).map(_map_fn) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureVariableInMapFn(self): + + def _build_ds(): + counter_var = variable_scope.get_variable( + "counter", (), dtypes.int32, use_resource=True) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda _: counter_var.assign_add(1))) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureConstantInMapFn(self): + + def _build_ds(): + constant_var = constant_op.constant(5) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda x: x + constant_var)) + + self.run_core_tests(_build_ds, None, 10) + + def testCaptureDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return dataset_ops.Dataset.range(num_outputs).map(defun_fn) + + self.run_core_tests(_build_ds, None, num_outputs) + + def testBuildDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + + @function.Defun(dtypes.int32) + def defun_fn_deep(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) + + return dataset_ops.Dataset.range(num_outputs).map(defun_fn) + + self.run_core_tests(_build_ds, None, num_outputs) + + def testSparseCore(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])) + + def _build_ds(num_outputs): + return dataset_ops.Dataset.range(num_outputs).map(_sparse) + + num_outputs = 10 + self.run_core_tests(lambda: _build_ds(num_outputs), + lambda: _build_ds(int(num_outputs / 2)), num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d5c03495e34e73018bf9832bf77cdcf038449488 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py @@ -0,0 +1,39 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the OptimizeDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class OptimizeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testCore(self): + + def build_dataset(num_elements, batch_size): + return dataset_ops.Dataset.range(num_elements).map(lambda x: x * x).batch( + batch_size).apply(optimization.optimize(["map_and_batch_fusion"])) + + self.run_core_tests(lambda: build_dataset(200, 10), None, 20) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9ac42a461afcb6803a0e033892e74fb84d1e5e58 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py @@ -0,0 +1,66 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the PaddedBatchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import string_ops +from tensorflow.python.platform import test + + +class PaddedBatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testPaddedBatch(self): + + def build_dataset(seq_lens): + return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( + lambda x: array_ops.fill([x], x)).padded_batch( + 4, padded_shapes=[-1]) + + seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) + seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) + self.run_core_tests(lambda: build_dataset(seq_lens1), + lambda: build_dataset(seq_lens2), 8) + + def testPaddedBatchNonDefaultPadding(self): + + def build_dataset(seq_lens): + + def fill_tuple(x): + filled = array_ops.fill([x], x) + return (filled, string_ops.as_string(filled)) + + padded_shape = [-1] + return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( + fill_tuple).padded_batch( + 4, + padded_shapes=(padded_shape, padded_shape), + padding_values=(-1, "")) + + seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) + seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) + self.run_core_tests(lambda: build_dataset(seq_lens1), + lambda: build_dataset(seq_lens2), 8) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1f8a584df902180aa7ab020b47ecc749912a3a3a --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py @@ -0,0 +1,101 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ParallelInterleaveDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +class ParallelInterleaveDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self.input_values = np.array([4, 5, 6], dtype=np.int64) + self.num_repeats = 2 + self.num_outputs = np.sum(self.input_values) * 2 + + def _build_ds(self, cycle_length, block_length, sloppy=False): + return (dataset_ops.Dataset.from_tensor_slices( + self.input_values).repeat(self.num_repeats).apply( + interleave_ops.parallel_interleave( + lambda x: dataset_ops.Dataset.range(10 * x, 11 * x), + cycle_length, block_length, sloppy))) + + def testSerializationCore(self): + # cycle_length > 1, block_length > 1 + cycle_length = 2 + block_length = 3 + self.run_core_tests( + lambda: self._build_ds(cycle_length, block_length), + lambda: self._build_ds(cycle_length * 2, block_length * 1), + self.num_outputs) + # cycle_length = 1 + cycle_length = 1 + block_length = 3 + self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), + None, self.num_outputs) + # block_length = 1 + cycle_length = 2 + block_length = 1 + self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), + None, self.num_outputs) + + def testSerializationWithSloppy(self): + break_points = self.gen_break_points(self.num_outputs, 10) + expected_outputs = np.repeat( + np.concatenate([np.arange(10 * x, 11 * x) for x in self.input_values]), + self.num_repeats).tolist() + + def run_test(cycle_length, block_length): + actual = self.gen_outputs( + lambda: self._build_ds(cycle_length, block_length, True), + break_points, self.num_outputs) + self.assertSequenceEqual(sorted(actual), expected_outputs) + + # cycle_length > 1, block_length > 1 + run_test(2, 3) + # cycle_length = 1 + run_test(1, 3) + # block_length = 1 + run_test(2, 1) + + def testSparseCore(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + def _build_dataset(): + return dataset_ops.Dataset.range(10).map(_map_fn).apply( + interleave_ops.parallel_interleave(_interleave_fn, 1)) + + self.run_core_tests(_build_dataset, None, 20) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3fb7605be1f230cef4cdae30aa672842a678edf7 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py @@ -0,0 +1,139 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ParallelMapDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test + + +class ParallelMapDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self._tensor_slice_len = 7 + self._num_epochs = 1 + self._num_outputs = self._tensor_slice_len * self._num_epochs + + def _build_ds(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + return (dataset_ops.Dataset.from_tensor_slices(components).map( + _map_fn, num_parallel_calls=3).repeat(self._num_epochs)) + + def _build_ds_with_prefetch(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + return (dataset_ops.Dataset.from_tensor_slices(components).map( + _map_fn, num_parallel_calls=3).repeat(self._num_epochs).prefetch(5)) + + def testSaveRestoreCore(self): + for ds_fn in [self._build_ds, self._build_ds_with_prefetch]: + self.run_core_tests( + ds_fn, + lambda: ds_fn(multiplier=15.0), + self._num_outputs) + + def testSaveStatefulFunction(self): + + def _build_ds(): + + def _map_fn(x): + return random_ops.random_uniform( + (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) + + return dataset_ops.Dataset.range(100).map( + _map_fn, num_parallel_calls=2).prefetch(2) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureVariableInMapFn(self): + + def _build_ds(): + counter_var = variable_scope.get_variable( + "counter", (), dtypes.int32, use_resource=True) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda _: counter_var.assign_add(1), + num_parallel_calls=2).prefetch(2)) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureConstantInMapFn(self): + + def _build_ds(): + constant_var = constant_op.constant(5) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda x: x + constant_var, num_parallel_calls=2).prefetch(2)) + + self.run_core_tests(_build_ds, None, 10) + + def testCaptureDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return dataset_ops.Dataset.range(num_outputs).map( + defun_fn, num_parallel_calls=2).prefetch(2) + + self.run_core_tests(_build_ds, None, num_outputs) + + def testBuildDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + + @function.Defun(dtypes.int32) + def defun_fn_deep(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) + + return dataset_ops.Dataset.range(num_outputs).map( + defun_fn, num_parallel_calls=2).prefetch(2) + + self.run_core_tests(_build_ds, None, num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py similarity index 90% rename from tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py index 3d120a3071ef730f21221e3291d8c84385b51aa3..c802402461216de33e7d3232ba38063c27f33557 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the PrefetchDataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e4f5b6cf5db788ad2fd09b7e93d0ae5ebb530a11 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py @@ -0,0 +1,118 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the RangeDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class RangeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _iterator_checkpoint_prefix_local(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def _save_op(self, iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + self._iterator_checkpoint_prefix_local(), + parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(self, iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(self._iterator_checkpoint_prefix_local()), + dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + + def testSaveRestore(self): + + def _build_graph(start, stop): + iterator = dataset_ops.Dataset.range(start, + stop).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + # Saving and restoring in different sessions. + start = 2 + stop = 10 + break_point = 5 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, _ = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next, _, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Saving and restoring in same session. + with ops.Graph().as_default() as g: + init_op, get_next, save_op, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def _build_range_dataset(self, start, stop): + return dataset_ops.Dataset.range(start, stop) + + def testRangeCore(self): + start = 2 + stop = 10 + stop_1 = 8 + self.run_core_tests(lambda: self._build_range_dataset(start, stop), + lambda: self._build_range_dataset(start, stop_1), + stop - start) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fdb35ea624c22ad0a9561d774c86247119c4c837 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py @@ -0,0 +1,46 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the SampleFromDatasets serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class SampleFromDatasetsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, probs, num_samples): + dataset = interleave_ops.sample_from_datasets( + [ + dataset_ops.Dataset.from_tensors(i).repeat(None) + for i in range(len(probs)) + ], + probs, + seed=1813) + return dataset.take(num_samples) + + def testSerializationCore(self): + self.run_core_tests( + lambda: self._build_dataset([0.5, 0.5], 100), + lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..af9ef48c0f3b92f61c097410ef4dfd787292e76a --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py @@ -0,0 +1,40 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ScanDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import scan_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class ScanDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, num_elements): + return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply( + scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1]))) + + def testScanCore(self): + num_output = 5 + self.run_core_tests(lambda: self._build_dataset(num_output), + lambda: self._build_dataset(2), num_output) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py similarity index 91% rename from tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py index d0cb203a3afd2775756c8542a1e86faedc5cee53..2afebca0f5849c640044830fff05ebff131e0875 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py @@ -12,19 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the sequence datasets serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.platform import test -class SequenceDatasetSerializationTest( +class SkipDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): def _build_skip_dataset(self, count): @@ -52,6 +52,10 @@ class SequenceDatasetSerializationTest( 'Shape must be rank 0 but is rank 1'): self.run_core_tests(lambda: self._build_skip_dataset([1, 2]), None, 0) + +class TakeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + def _build_take_dataset(self, count): components = (np.arange(10),) return dataset_ops.Dataset.from_tensor_slices(components).take(count) @@ -79,6 +83,10 @@ class SequenceDatasetSerializationTest( 'Shape must be rank 0 but is rank 1'): self.run_core_tests(lambda: self._build_take_dataset([1, 2]), None, 0) + +class RepeatDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + def _build_repeat_dataset(self, count, take_count=3): components = (np.arange(10),) return dataset_ops.Dataset.from_tensor_slices(components).take( @@ -117,5 +125,5 @@ class SequenceDatasetSerializationTest( None, 0) -if __name__ == "__main__": +if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization_integration_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py similarity index 96% rename from tensorflow/contrib/data/python/kernel_tests/serialization_integration_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py index 0a6b74dc3eb80a6168117beed06935737198cecb..992d996a485de94ad55305552e42c7fbc92ec64b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization_integration_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Integration test for input pipeline serialization.""" +"""Integration test for dataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -26,7 +26,7 @@ from tensorflow.python.platform import test from tensorflow.python.training import saver as saver_lib -class MultipleInputPipelinesTest(test.TestCase): +class SerializationIntegrationTest(test.TestCase): def _build_input_pipeline(self, name, num_outputs): with ops.name_scope(name): diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f199ec835ef1c72e2c3f8b3b1cc4f5fe6ea0b6f4 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py @@ -0,0 +1,39 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ShuffleAndRepeatDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import shuffle_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class ShuffleAndRepeatSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, seed): + return dataset_ops.Dataset.range(20).apply( + shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed)) + + def testCore(self): + self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20), + 100) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d46c762aaaadc4314a10acc5aeb7ace7df5002a8 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py @@ -0,0 +1,148 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ShuffleDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib + + +class ShuffleDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_shuffle_dataset( + self, + range_limit=10, + num_repeats=5, + buffer_size=5, + seed=None, + reshuffle_each_iteration=None, + ): + return dataset_ops.Dataset.range(range_limit).shuffle( + buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration).repeat(num_repeats) + + def testShuffleCore(self): + + seed = 55 + range_limit = 5 + num_repeats = 2 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 5, 8, 10] + # pylint: disable=cell-var-from-loop + # pylint: disable=g-long-lambda + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + self.run_core_tests( + lambda: self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration), + lambda: self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=10, + reshuffle_each_iteration=reshuffle_each_iteration), + num_outputs) + # pylint: enable=cell-var-from-loop + # pylint: enable=g-long-lambda + + def testNonDeterministicSeeding(self): + + range_limit = 5 + num_repeats = 2 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 5, 8, 10] + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + + def ds_fn(): + # pylint: disable=cell-var-from-loop + return self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=None, # Iterator seeds are generated non-deterministically. + reshuffle_each_iteration=reshuffle_each_iteration) + # pylint: enable=cell-var-from-loop + + # We checkpoint the initial state of the Dataset so that we can restore + # the seeds in the next run. Since the seeding is non-deterministic + # the dataset gets initialized with different seeds each time. + expected = self.gen_outputs( + ds_fn, + break_points=[0], + num_outputs=num_outputs, + ckpt_saved=False, + verify_exhausted=False, + save_checkpoint_at_end=False) + actual = self.gen_outputs( + ds_fn, + break_points=self.gen_break_points(num_outputs), + num_outputs=num_outputs, + ckpt_saved=True, + verify_exhausted=False) + self.match(expected, actual) + + def testMultipleIterators(self): + range_limit = 5 + num_repeats = 2 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 5, 8, 10] + + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + + def ds_fn(): + # pylint: disable=cell-var-from-loop + return self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=None, # Iterator seeds are generated non-deterministically. + reshuffle_each_iteration=reshuffle_each_iteration) + # pylint: enable=cell-var-from-loop + + with ops.Graph().as_default() as g: + ds = ds_fn() + iterators = [ds.make_one_shot_iterator(), ds.make_one_shot_iterator()] + get_next_ops = [it.get_next() for it in iterators] + saveables = [ + contrib_iterator_ops.make_saveable_from_iterator(it) + for it in iterators + ] + for saveable in saveables: + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + saver = saver_lib.Saver(allow_empty=True) + with self.test_session(graph=g) as sess: + self._save(sess, saver) + expected = [sess.run(get_next_ops) for _ in range(num_outputs)] + self._restore(saver, sess) + actual = [sess.run(get_next_ops) for _ in range(num_outputs)] + self.match(expected, actual) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..93b26ed58a065de2074906528a0f49d696a813ff --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py @@ -0,0 +1,53 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the SqlDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.kernel_tests import sql_dataset_op_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class SqlDatasetSerializationTest( + sql_dataset_op_test_base.SqlDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, num_repeats): + data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") + driver_name = array_ops.placeholder_with_default( + array_ops.constant("sqlite", dtypes.string), shape=[]) + query = ("SELECT first_name, last_name, motto FROM students ORDER BY " + "first_name DESC") + output_types = (dtypes.string, dtypes.string, dtypes.string) + return readers.SqlDataset(driver_name, data_source_name, query, + output_types).repeat(num_repeats) + + def testSQLSaveable(self): + num_repeats = 4 + num_outputs = num_repeats * 2 + self.run_core_tests(lambda: self._build_dataset(num_repeats), + lambda: self._build_dataset(num_repeats // 2), + num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..14cd3e9c4a72cc7832f9bb1cb49c72a8a7cb2dcd --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py @@ -0,0 +1,95 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the StatsDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import stats_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +# TODO(shivaniagrawal): Can not checkpoint input_pipeline with the +# transformation `stats_ops.set_stats_aggregator`, since we don't support +# serializing StatsAggregator yet. +class StatsDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset_bytes_stats(self, num_elements): + return dataset_ops.Dataset.range(num_elements).map( + lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply( + stats_ops.bytes_produced_stats("bytes_produced")) + + def test_bytes_produced_stats_invalid_tag_shape(self): + with self.assertRaisesRegexp( + ValueError, "Shape must be rank 0 but is rank 1"): + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: dataset_ops.Dataset.range(100).apply( + stats_ops.bytes_produced_stats(["bytes_produced"])), + None, 100) + # pylint: enable=g-long-lambda + + def testBytesStatsDatasetSaveableCore(self): + num_outputs = 100 + self.run_core_tests( + lambda: self._build_dataset_bytes_stats(num_outputs), + lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs) + + def _build_dataset_latency_stats(self, num_elements, tag="record_latency"): + return dataset_ops.Dataset.range(num_elements).apply( + stats_ops.latency_stats(tag)) + + def _build_dataset_multiple_tags(self, + num_elements, + tag1="record_latency", + tag2="record_latency_2"): + return dataset_ops.Dataset.range(num_elements).apply( + stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2)) + + def test_latency_stats_invalid_tag_shape(self): + with self.assertRaisesRegexp( + ValueError, "Shape must be rank 0 but is rank 1"): + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats(["record_latency", "record_latency_2"])), + None, 100) + # pylint: enable=g-long-lambda + + def testLatencyStatsDatasetSaveableCore(self): + num_outputs = 100 + + self.run_core_tests( + lambda: self._build_dataset_latency_stats(num_outputs), + lambda: self._build_dataset_latency_stats(num_outputs // 10), + num_outputs) + + self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs), + None, num_outputs) + + tag1 = "record_latency" + tag2 = "record_latency" + self.run_core_tests( + lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2), + None, num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2483787f44f913199e3f2aa46d181d609a4a9a8f --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the TextLineDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.platform import test + + +class TextLineDatasetSerializationTest( + reader_dataset_ops_test_base.TextLineDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, test_filenames, compression_type=None): + return core_readers.TextLineDataset( + test_filenames, compression_type=compression_type, buffer_size=10) + + def testTextLineCore(self): + compression_types = [None, "GZIP", "ZLIB"] + num_files = 5 + lines_per_file = 5 + num_outputs = num_files * lines_per_file + for compression_type in compression_types: + test_filenames = self._createFiles( + num_files, + lines_per_file, + crlf=True, + compression_type=compression_type) + # pylint: disable=cell-var-from-loop + self.run_core_tests( + lambda: self._build_iterator_graph(test_filenames, compression_type), + lambda: self._build_iterator_graph(test_filenames), num_outputs) + # pylint: enable=cell-var-from-loop + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..55a6257a274cd7f78e3818943627cfa09a185fd7 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py @@ -0,0 +1,99 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the TFRecordDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os +import zlib + +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.platform import test + + +class TFRecordDatasetSerializationTest( + reader_dataset_ops_test_base.TFRecordDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, + num_epochs, + batch_size=1, + compression_type=None, + buffer_size=None): + filenames = self._createFiles() + if compression_type == "ZLIB": + zlib_files = [] + for i, fn in enumerate(filenames): + with open(fn, "rb") as f: + cdata = zlib.compress(f.read()) + zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i) + with open(zfn, "wb") as f: + f.write(cdata) + zlib_files.append(zfn) + filenames = zlib_files + + elif compression_type == "GZIP": + gzip_files = [] + for i, fn in enumerate(self.test_filenames): + with open(fn, "rb") as f: + gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i) + with gzip.GzipFile(gzfn, "wb") as gzf: + gzf.write(f.read()) + gzip_files.append(gzfn) + filenames = gzip_files + + return core_readers.TFRecordDataset( + filenames, compression_type, + buffer_size=buffer_size).repeat(num_epochs).batch(batch_size) + + def testTFRecordWithoutBufferCore(self): + num_epochs = 5 + batch_size = num_epochs + num_outputs = num_epochs * self._num_files * self._num_records // batch_size + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, batch_size, + buffer_size=0), + lambda: self._build_iterator_graph(num_epochs * 2, batch_size), + num_outputs) + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, buffer_size=0), None, + num_outputs * batch_size) + # pylint: enable=g-long-lambda + + def testTFRecordWithBufferCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), + lambda: self._build_iterator_graph(num_epochs * 2), + num_outputs) + + def testTFRecordWithCompressionCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, compression_type="ZLIB"), + lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, compression_type="GZIP"), + lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b2a5a8a20dd7a9f891b07351570006636ca34bd0 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py @@ -0,0 +1,51 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the UnbatchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import batching +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class UnbatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): + components = ( + np.arange(tensor_slice_len), + np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(tensor_slice_len)) + + return dataset_ops.Dataset.from_tensor_slices(components).batch( + batch_size).apply(batching.unbatch()) + + def testCore(self): + tensor_slice_len = 8 + batch_size = 2 + num_outputs = tensor_slice_len + self.run_core_tests( + lambda: self.build_dataset(15.0, tensor_slice_len, batch_size), + lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), + num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..22f15b88464a770207dc7c6f0387d73ea3d5c2e4 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py @@ -0,0 +1,40 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the UniqueDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import unique +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class UniqueDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testUnique(self): + + def build_dataset(num_elements, unique_elem_range): + return dataset_ops.Dataset.range(num_elements).map( + lambda x: x % unique_elem_range).apply(unique.unique()) + + self.run_core_tests(lambda: build_dataset(200, 100), + lambda: build_dataset(40, 100), 100) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py similarity index 92% rename from tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py index e39fa957f0bbb9d3671274d5f58b993e8399814b..340a6ff72e6813c3743d3d83a72ac12d4a392b66 100644 --- a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the ZipDataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index 25e9ea47b82dad479f041a7be37c984f96c95e0e..3c11d7a97fc9a4b2b8b19a8e82ad5e9037d6bbcd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -19,144 +19,32 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import test -from tensorflow.python.training import saver as saver_lib - - -class ShuffleDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_shuffle_dataset( - self, - range_limit=10, - num_repeats=5, - buffer_size=5, - seed=None, - reshuffle_each_iteration=None, - ): - return dataset_ops.Dataset.range(range_limit).shuffle( - buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration).repeat(num_repeats) - - def testShuffleCore(self): - - seed = 55 - range_limit = 5 - num_repeats = 2 - num_outputs = range_limit * num_repeats - buffer_sizes = [1, 3, 5, 8, 10] - # pylint: disable=cell-var-from-loop - # pylint: disable=g-long-lambda - for reshuffle_each_iteration in [True, False]: - for buffer_size in buffer_sizes: - self.run_core_tests( - lambda: self._build_shuffle_dataset( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration), - lambda: self._build_shuffle_dataset( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=10, - reshuffle_each_iteration=reshuffle_each_iteration), - num_outputs) - # pylint: enable=cell-var-from-loop - # pylint: enable=g-long-lambda - - def testNonDeterministicSeeding(self): - - range_limit = 5 - num_repeats = 2 - num_outputs = range_limit * num_repeats - buffer_sizes = [1, 3, 5, 8, 10] - for reshuffle_each_iteration in [True, False]: - for buffer_size in buffer_sizes: - - def ds_fn(): - # pylint: disable=cell-var-from-loop - return self._build_shuffle_dataset( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=None, # Iterator seeds are generated non-deterministically. - reshuffle_each_iteration=reshuffle_each_iteration) - # pylint: enable=cell-var-from-loop - - # We checkpoint the initial state of the Dataset so that we can restore - # the seeds in the next run. Since the seeding is non-deterministic - # the dataset gets initialized with different seeds each time. - expected = self.gen_outputs( - ds_fn, - break_points=[0], - num_outputs=num_outputs, - ckpt_saved=False, - verify_exhausted=False, - save_checkpoint_at_end=False) - actual = self.gen_outputs( - ds_fn, - break_points=self.gen_break_points(num_outputs), - num_outputs=num_outputs, - ckpt_saved=True, - verify_exhausted=False) - self.match(expected, actual) - - def testMultipleIterators(self): - range_limit = 5 - num_repeats = 2 - num_outputs = range_limit * num_repeats - buffer_sizes = [1, 3, 5, 8, 10] - - for reshuffle_each_iteration in [True, False]: - for buffer_size in buffer_sizes: - - def ds_fn(): - # pylint: disable=cell-var-from-loop - return self._build_shuffle_dataset( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=None, # Iterator seeds are generated non-deterministically. - reshuffle_each_iteration=reshuffle_each_iteration) - # pylint: enable=cell-var-from-loop - - with ops.Graph().as_default() as g: - ds = ds_fn() - iterators = [ds.make_one_shot_iterator(), ds.make_one_shot_iterator()] - get_next_ops = [it.get_next() for it in iterators] - saveables = [ - contrib_iterator_ops.make_saveable_from_iterator(it) - for it in iterators - ] - for saveable in saveables: - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - saver = saver_lib.Saver(allow_empty=True) - with self.test_session(graph=g) as sess: - self._save(sess, saver) - expected = [sess.run(get_next_ops) for _ in range(num_outputs)] - self._restore(saver, sess) - actual = [sess.run(get_next_ops) for _ in range(num_outputs)] - self.match(expected, actual) - - -class ShuffleAndRepeatTest( - dataset_serialization_test_base.DatasetSerializationTestBase): + + +class ShuffleAndRepeatTest(test.TestCase): def _build_ds(self, seed, count=5, num_elements=20): return dataset_ops.Dataset.range(num_elements).apply( shuffle_ops.shuffle_and_repeat(buffer_size=5, count=count, seed=seed)) + def _gen_outputs(self, ds_fn, num_outputs, verify_exhausted=True): + get_next = ds_fn().make_one_shot_iterator().get_next() + outputs = [] + with self.test_session() as sess: + for _ in range(num_outputs): + outputs.append(sess.run(get_next)) + if verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + return outputs + def testCorrectOutput(self): - output = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output = self._gen_outputs(lambda: self._build_ds(10), 100) self.assertSequenceEqual( sorted(output), sorted( np.array([range(20) for _ in range(5)]).flatten())) @@ -165,53 +53,53 @@ class ShuffleAndRepeatTest( def testReshuffling(self): # Check that the output orders of different epochs are indeed different. - output = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output = self._gen_outputs(lambda: self._build_ds(10), 100) for i in range(4): epoch1 = output[i * 20:(i + 1) * 20] epoch2 = output[(i + 1) * 20:(i + 2) * 20] self.assertNotEqual(epoch1, epoch2) def testSameOrderForSameSeeds(self): - output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100) - output2 = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output1 = self._gen_outputs(lambda: self._build_ds(10), 100) + output2 = self._gen_outputs(lambda: self._build_ds(10), 100) self.assertEqual(output1, output2) def testDifferentOrderForDifferentSeeds(self): - output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100) - output2 = self.gen_outputs(lambda: self._build_ds(20), [], 100) + output1 = self._gen_outputs(lambda: self._build_ds(10), 100) + output2 = self._gen_outputs(lambda: self._build_ds(20), 100) self.assertNotEqual(output1, output2) self.assertEqual(sorted(output1), sorted(output2)) def testCountNone(self): - output1 = self.gen_outputs( - lambda: self._build_ds(10, count=None), [], 100, verify_exhausted=False) - output2 = self.gen_outputs( - lambda: self._build_ds(20, count=None), [], 100, verify_exhausted=False) + output1 = self._gen_outputs( + lambda: self._build_ds(10, count=None), 100, verify_exhausted=False) + output2 = self._gen_outputs( + lambda: self._build_ds(20, count=None), 100, verify_exhausted=False) self.assertNotEqual(output1, output2) self.assertEqual(sorted(output1), sorted(output2)) def testCountMinusOne(self): - output1 = self.gen_outputs( - lambda: self._build_ds(10, count=-1), [], 100, verify_exhausted=False) - output2 = self.gen_outputs( - lambda: self._build_ds(20, count=-1), [], 100, verify_exhausted=False) + output1 = self._gen_outputs( + lambda: self._build_ds(10, count=-1), 100, verify_exhausted=False) + output2 = self._gen_outputs( + lambda: self._build_ds(20, count=-1), 100, verify_exhausted=False) self.assertNotEqual(output1, output2) self.assertEqual(sorted(output1), sorted(output2)) def testInfiniteOutputs(self): # Asserting the iterator is exhausted after producing 100 items should fail. with self.assertRaises(AssertionError): - self.gen_outputs(lambda: self._build_ds(10, count=None), [], 100) + self._gen_outputs(lambda: self._build_ds(10, count=None), 100) with self.assertRaises(AssertionError): - self.gen_outputs(lambda: self._build_ds(10, count=-1), [], 100) + self._gen_outputs(lambda: self._build_ds(10, count=-1), 100) def testInfiniteEmpty(self): with self.assertRaises(errors.OutOfRangeError): - self.gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0), - [], 100) + self._gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0), + 100) with self.assertRaises(errors.OutOfRangeError): - self.gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0), [], - 100) + self._gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0), + 100) def testLargeBufferSize(self): with ops.Graph().as_default() as g: @@ -222,17 +110,5 @@ class ShuffleAndRepeatTest( sess.run(get_next_op) -class ShuffleAndRepeatSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_ds(self, seed): - return dataset_ops.Dataset.range(20).apply( - shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed)) - - def testCore(self): - self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20), - 100) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index 33c48e20bea53b88d69a59e715af38b22dd2cbd4..5590a4bf783d12b0d0710c0130b0b1df921c9baa 100644 --- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -58,6 +58,7 @@ class SlideDatasetTest(test.TestCase): [t.shape.as_list() for t in get_next]) with self.test_session() as sess: + # stride < window_size. # Slide over a finite input, where the window_size divides the # total number of elements. sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 7}) @@ -71,11 +72,9 @@ class SlideDatasetTest(test.TestCase): result_component[j]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - # Slide over a finite input, where the window_size does not # divide the total number of elements. sess.run(init_op, feed_dict={count: 20, window_size: 17, stride: 9}) - num_batches = (20 * 7 - 17) // 9 + 1 for i in range(num_batches): result = sess.run(get_next) @@ -86,6 +85,41 @@ class SlideDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + # stride == window_size. + sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 14}) + num_batches = 20 * 7 // 14 + for i in range(num_batches): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(14): + self.assertAllEqual(component[(i*14 + j) % 7]**2, + result_component[j]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # stride > window_size. + sess.run(init_op, feed_dict={count: 20, window_size: 10, stride: 14}) + num_batches = 20 * 7 // 14 + for i in range(num_batches): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(10): + self.assertAllEqual(component[(i*14 + j) % 7]**2, + result_component[j]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + # Drop the last batch which is smaller than window_size. + sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 19}) + num_batches = (20 * 7 - 7) // 19 # = 19 * 7 // 19 + for i in range(num_batches): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(14): + self.assertAllEqual(component[(i*19 + j) % 7]**2, + result_component[j]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + # Slide over a finite input, which is less than window_size, # should fail straight away. sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 4}) @@ -108,10 +142,6 @@ class SlideDatasetTest(test.TestCase): # Invalid stride should be an initialization time error. with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 0}) - with self.assertRaises(errors.InvalidArgumentError): - sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 3}) - with self.assertRaises(errors.InvalidArgumentError): - sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 5}) def assertSparseValuesEqual(self, a, b): self.assertAllEqual(a.indices, b.indices) diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py index 4148addf2878c99f47ebe1454edf69ad7f38dfbc..2c2cfbebff5d3eba00f120467102b4185d81ab24 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py @@ -18,83 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - -import sqlite3 - -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import readers +from tensorflow.contrib.data.python.kernel_tests import sql_dataset_op_test_base from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class SqlDatasetTestBase(test.TestCase): - - def _createSqlDataset(self, output_types, num_repeats=1): - dataset = readers.SqlDataset(self.driver_name, self.data_source_name, - self.query, output_types).repeat(num_repeats) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - return init_op, get_next - - def setUp(self): - self.data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") - self.driver_name = array_ops.placeholder_with_default( - array_ops.constant("sqlite", dtypes.string), shape=[]) - self.query = array_ops.placeholder(dtypes.string, shape=[]) - - conn = sqlite3.connect(self.data_source_name) - c = conn.cursor() - c.execute("DROP TABLE IF EXISTS students") - c.execute("DROP TABLE IF EXISTS people") - c.execute("DROP TABLE IF EXISTS townspeople") - c.execute( - "CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY, " - "first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100), " - "school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), " - "desk_number INTEGER, income INTEGER, favorite_number INTEGER, " - "favorite_big_number INTEGER, favorite_negative_number INTEGER, " - "favorite_medium_sized_number INTEGER, brownie_points INTEGER, " - "account_balance INTEGER, registration_complete INTEGER)") - c.executemany( - "INSERT INTO students (first_name, last_name, motto, school_id, " - "favorite_nonsense_word, desk_number, income, favorite_number, " - "favorite_big_number, favorite_negative_number, " - "favorite_medium_sized_number, brownie_points, account_balance, " - "registration_complete) " - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - [("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647, - 9223372036854775807, -2, 32767, 0, 0, 1), - ("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 127, -20000, - -2147483648, -9223372036854775808, -128, -32768, 255, 65535, 0)]) - c.execute( - "CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, " - "first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))") - c.executemany( - "INSERT INTO PEOPLE (first_name, last_name, state) VALUES (?, ?, ?)", - [("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe", - "California")]) - c.execute( - "CREATE TABLE IF NOT EXISTS townspeople (id INTEGER NOT NULL PRIMARY " - "KEY, first_name VARCHAR(100), last_name VARCHAR(100), victories " - "FLOAT, accolades FLOAT, triumphs FLOAT)") - c.executemany( - "INSERT INTO townspeople (first_name, last_name, victories, " - "accolades, triumphs) VALUES (?, ?, ?, ?, ?)", - [("George", "Washington", 20.00, - 1331241.321342132321324589798264627463827647382647382643874, - 9007199254740991.0), - ("John", "Adams", -19.95, - 1331241321342132321324589798264627463827647382647382643874.0, - 9007199254740992.0)]) - conn.commit() - conn.close() - - -class SqlDatasetTest(SqlDatasetTestBase): +class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # Test that SqlDataset can read from a database table. def testReadResultSet(self): @@ -656,27 +586,5 @@ class SqlDatasetTest(SqlDatasetTestBase): sess.run(get_next) -class SqlDatasetSerializationTest( - SqlDatasetTestBase, - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset(self, num_repeats): - data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") - driver_name = array_ops.placeholder_with_default( - array_ops.constant("sqlite", dtypes.string), shape=[]) - query = ("SELECT first_name, last_name, motto FROM students ORDER BY " - "first_name DESC") - output_types = (dtypes.string, dtypes.string, dtypes.string) - return readers.SqlDataset(driver_name, data_source_name, query, - output_types).repeat(num_repeats) - - def testSQLSaveable(self): - num_repeats = 4 - num_outputs = num_repeats * 2 - self.run_core_tests(lambda: self._build_dataset(num_repeats), - lambda: self._build_dataset(num_repeats // 2), - num_outputs) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..1f5c725a9269e80311f3e73c51c28ab80e7c4815 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py @@ -0,0 +1,96 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for testing SqlDataset.""" + + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import sqlite3 + +from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class SqlDatasetTestBase(test.TestCase): + """Base class for setting up and testing SqlDataset.""" + + def _createSqlDataset(self, output_types, num_repeats=1): + dataset = readers.SqlDataset(self.driver_name, self.data_source_name, + self.query, output_types).repeat(num_repeats) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + return init_op, get_next + + def setUp(self): + self.data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") + self.driver_name = array_ops.placeholder_with_default( + array_ops.constant("sqlite", dtypes.string), shape=[]) + self.query = array_ops.placeholder(dtypes.string, shape=[]) + + conn = sqlite3.connect(self.data_source_name) + c = conn.cursor() + c.execute("DROP TABLE IF EXISTS students") + c.execute("DROP TABLE IF EXISTS people") + c.execute("DROP TABLE IF EXISTS townspeople") + c.execute( + "CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY, " + "first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100), " + "school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), " + "desk_number INTEGER, income INTEGER, favorite_number INTEGER, " + "favorite_big_number INTEGER, favorite_negative_number INTEGER, " + "favorite_medium_sized_number INTEGER, brownie_points INTEGER, " + "account_balance INTEGER, registration_complete INTEGER)") + c.executemany( + "INSERT INTO students (first_name, last_name, motto, school_id, " + "favorite_nonsense_word, desk_number, income, favorite_number, " + "favorite_big_number, favorite_negative_number, " + "favorite_medium_sized_number, brownie_points, account_balance, " + "registration_complete) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + [("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647, + 9223372036854775807, -2, 32767, 0, 0, 1), + ("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 127, -20000, + -2147483648, -9223372036854775808, -128, -32768, 255, 65535, 0)]) + c.execute( + "CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, " + "first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))") + c.executemany( + "INSERT INTO PEOPLE (first_name, last_name, state) VALUES (?, ?, ?)", + [("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe", + "California")]) + c.execute( + "CREATE TABLE IF NOT EXISTS townspeople (id INTEGER NOT NULL PRIMARY " + "KEY, first_name VARCHAR(100), last_name VARCHAR(100), victories " + "FLOAT, accolades FLOAT, triumphs FLOAT)") + c.executemany( + "INSERT INTO townspeople (first_name, last_name, victories, " + "accolades, triumphs) VALUES (?, ?, ?, ?, ?)", + [("George", "Washington", 20.00, + 1331241.321342132321324589798264627463827647382647382643874, + 9007199254740991.0), + ("John", "Adams", -19.95, + 1331241321342132321324589798264627463827647382647382643874.0, + 9007199254740992.0)]) + conn.commit() + conn.close() + + diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index 17b6644759e53f84b23e070a71267aa15dcffe49..b4945685c1d1062bf416b73f1541f351adf45604 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -19,7 +19,6 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.core.framework import summary_pb2 @@ -236,68 +235,5 @@ class FeatureStatsDatasetTest( self._sum_keywords(1) * num_epochs + 2 * total_records) -class StatsDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset_bytes_stats(self, num_elements): - return dataset_ops.Dataset.range(num_elements).map( - lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply( - stats_ops.bytes_produced_stats("bytes_produced")) - - def test_bytes_produced_stats_invalid_tag_shape(self): - with self.assertRaisesRegexp( - ValueError, 'Shape must be rank 0 but is rank 1'): - self.run_core_tests( - lambda: dataset_ops.Dataset.range(100).apply( - stats_ops.bytes_produced_stats(["bytes_produced"])), - None, 100) - - def testBytesStatsDatasetSaveableCore(self): - num_outputs = 100 - self.run_core_tests( - lambda: self._build_dataset_bytes_stats(num_outputs), - lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs) - - def _build_dataset_latency_stats(self, num_elements, tag="record_latency"): - return dataset_ops.Dataset.range(num_elements).apply( - stats_ops.latency_stats(tag)) - - def _build_dataset_multiple_tags(self, - num_elements, - tag1="record_latency", - tag2="record_latency_2"): - return dataset_ops.Dataset.range(num_elements).apply( - stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2)) - - def test_latency_stats_invalid_tag_shape(self): - with self.assertRaisesRegexp( - ValueError, 'Shape must be rank 0 but is rank 1'): - self.run_core_tests( - lambda: dataset_ops.Dataset.range(100).apply( - stats_ops.latency_stats(["record_latency", "record_latency_2"])), - None, 100) - - def testLatencyStatsDatasetSaveableCore(self): - num_outputs = 100 - - self.run_core_tests( - lambda: self._build_dataset_latency_stats(num_outputs), - lambda: self._build_dataset_latency_stats(num_outputs // 10), - num_outputs) - - self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs), - None, num_outputs) - - tag1 = "record_latency" - tag2 = "record_latency" - self.run_core_tests( - lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2), - None, num_outputs) - - -# TODO(shivaniagrawal): Can not checkpoint input_pipeline with the -# transformation `stats_ops.set_stats_aggregator`, since we don't support -# serializing StatsAggregator yet. - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py index 9167cb3379bba5cb1ba76a96549395c45dca9e35..0486e2bce20e9dcf81dcb5ac49fe5b397e44bf0c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import threading +from absl.testing import parameterized import numpy as np from tensorflow.contrib.data.python.ops import threadpool @@ -30,9 +31,11 @@ from tensorflow.python.ops import script_ops from tensorflow.python.platform import test -class OverrideThreadpoolDatasetTest(test.TestCase): +class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase): - def testNumThreads(self): + @parameterized.parameters((1, None), (2, None), (4, None), (8, None), + (16, None), (4, -1), (4, 0), (4, 1), (4, 4)) + def testNumThreads(self, num_threads, max_intra_op_parallelism): def get_thread_id(_): # Python creates a dummy thread object to represent the current @@ -42,35 +45,35 @@ class OverrideThreadpoolDatasetTest(test.TestCase): # identifier that maps one-to-one with the underlying OS thread. return np.array(threading.current_thread().ident).astype(np.int64) - for num_threads in [1, 2, 4, 8, 16]: + dataset = ( + dataset_ops.Dataset.range(1000).map( + lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64), + num_parallel_calls=32).apply(unique.unique())) - dataset = ( - dataset_ops.Dataset.range(1000).map( - lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64), - num_parallel_calls=32).apply(unique.unique())) + dataset = threadpool.override_threadpool( + dataset, + threadpool.PrivateThreadPool( + num_threads, + max_intra_op_parallelism=max_intra_op_parallelism, + display_name="private_thread_pool_%d" % num_threads)) - dataset = threadpool.override_threadpool( - dataset, - threadpool.PrivateThreadPool( - num_threads, display_name="private_thread_pool_%d" % num_threads)) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() - iterator = dataset.make_initializable_iterator() - next_element = iterator.get_next() - - with self.test_session() as sess: - sess.run(iterator.initializer) - thread_ids = [] - try: - while True: - thread_ids.append(sess.run(next_element)) - except errors.OutOfRangeError: - pass - self.assertEqual(len(thread_ids), len(set(thread_ids))) - self.assertGreater(len(thread_ids), 0) - # NOTE(mrry): We don't control the thread pool scheduling, and - # so cannot guarantee that all of the threads in the pool will - # perform work. - self.assertLessEqual(len(thread_ids), num_threads) + with self.test_session() as sess: + sess.run(iterator.initializer) + thread_ids = [] + try: + while True: + thread_ids.append(sess.run(next_element)) + except errors.OutOfRangeError: + pass + self.assertEqual(len(thread_ids), len(set(thread_ids))) + self.assertGreater(len(thread_ids), 0) + # NOTE(mrry): We don't control the thread pool scheduling, and + # so cannot guarantee that all of the threads in the pool will + # perform work. + self.assertLessEqual(len(thread_ids), num_threads) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py index 3c436f7a0b45a13109960e87dd97ca56b10bb871..d79a842e7a5d816e2e6a52fc83acbd6b260cf64b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import unique from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes @@ -79,18 +78,5 @@ class UniqueDatasetTest(test.TestCase): ]) -class UniqueSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def testUnique(self): - - def build_dataset(num_elements, unique_elem_range): - return dataset_ops.Dataset.range(num_elements).map( - lambda x: x % unique_elem_range).apply(unique.unique()) - - self.run_core_tests(lambda: build_dataset(200, 100), - lambda: build_dataset(40, 100), 100) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 33b7a75046cf2acfa3d787833b907aa2b28dbdca..02408145625b7e751541e7b87dc4fd5da4f7cad9 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -49,26 +49,6 @@ py_library( ], ) -py_test( - name = "iterator_ops_test", - size = "small", - srcs = ["iterator_ops_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":iterator_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", - "//tensorflow/python:variables", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/estimator", - "//tensorflow/python/estimator:model_fn", - ], -) - py_library( name = "random_ops", srcs = [ diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 052618e08c8f204613db5a20d42e078f17f12840..7350d595f5f6b64d062dcc5ebc69d7e85d3f7b22 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -77,17 +77,17 @@ def dense_to_sparse_batch(batch_size, row_shape): """ def _apply_fn(dataset): - return DenseToSparseBatchDataset(dataset, batch_size, row_shape) + return _DenseToSparseBatchDataset(dataset, batch_size, row_shape) return _apply_fn -class UnbatchDataset(dataset_ops.Dataset): +class _UnbatchDataset(dataset_ops.Dataset): """A dataset that splits the elements of its input into multiple elements.""" def __init__(self, input_dataset): """See `unbatch()` for more details.""" - super(UnbatchDataset, self).__init__() + super(_UnbatchDataset, self).__init__() flat_shapes = nest.flatten(input_dataset.output_shapes) if any(s.ndims == 0 for s in flat_shapes): raise ValueError("Cannot unbatch an input with scalar components.") @@ -144,7 +144,7 @@ def unbatch(): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" if not sparse.any_sparse(dataset.output_classes): - return UnbatchDataset(dataset) + return _UnbatchDataset(dataset) # NOTE(mrry): We must ensure that any SparseTensors in `dataset` # are normalized to the rank-1 dense representation, so that the @@ -170,12 +170,12 @@ def unbatch(): dataset.output_shapes, dataset.output_classes, allow_unsafe_cast=True) - return UnbatchDataset(restructured_dataset) + return _UnbatchDataset(restructured_dataset) return _apply_fn -def filter_irregular_batches(batch_size): +def _filter_irregular_batches(batch_size): """Transformation that filters out batches that are not of size batch_size.""" def _apply_fn(dataset): @@ -254,7 +254,7 @@ def batch_and_drop_remainder(batch_size): # TODO(jsimsa): Switch to using `batch(..., drop_remainder=True)` any time # after 6/30/2018. batched = dataset.batch(batch_size) - return filter_irregular_batches(batch_size)(batched) + return _filter_irregular_batches(batch_size)(batched) return _apply_fn @@ -293,17 +293,17 @@ def padded_batch_and_drop_remainder(batch_size, # any time after 6/30/2018. batched = dataset.padded_batch( batch_size, padded_shapes=padded_shapes, padding_values=padding_values) - return filter_irregular_batches(batch_size)(batched) + return _filter_irregular_batches(batch_size)(batched) return _apply_fn -class DenseToSparseBatchDataset(dataset_ops.Dataset): +class _DenseToSparseBatchDataset(dataset_ops.Dataset): """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s.""" def __init__(self, input_dataset, batch_size, row_shape): """See `Dataset.dense_to_sparse_batch()` for more details.""" - super(DenseToSparseBatchDataset, self).__init__() + super(_DenseToSparseBatchDataset, self).__init__() if not isinstance(input_dataset.output_types, dtypes.DType): raise TypeError("DenseToSparseDataset requires an input whose elements " "have a single component, whereas the input has %r." % diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index 5f5513849cb29a18b86ba8bcee1ab6c9c60674cb..d46d96c461ad4cc0ac25a8ddc285cec23d09c682 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -46,17 +46,17 @@ def ignore_errors(): """ def _apply_fn(dataset): - return IgnoreErrorsDataset(dataset) + return _IgnoreErrorsDataset(dataset) return _apply_fn -class IgnoreErrorsDataset(dataset_ops.Dataset): +class _IgnoreErrorsDataset(dataset_ops.Dataset): """A `Dataset` that silently ignores errors when computing its input.""" def __init__(self, input_dataset): """See `Dataset.ignore_errors()` for details.""" - super(IgnoreErrorsDataset, self).__init__() + super(_IgnoreErrorsDataset, self).__init__() self._input_dataset = input_dataset def _as_variant_tensor(self): diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index 4068a2ffa5ab877c372a6f32e3430812aa138391..ca9540bf136a5028c4321319bdfacaf8a16484c7 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -55,7 +55,7 @@ def group_by_reducer(key_func, reducer): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - return GroupByReducerDataset(dataset, key_func, reducer) + return _GroupByReducerDataset(dataset, key_func, reducer) return _apply_fn @@ -113,8 +113,8 @@ def group_by_window(key_func, def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - return GroupByWindowDataset(dataset, key_func, reduce_func, - window_size_func) + return _GroupByWindowDataset(dataset, key_func, reduce_func, + window_size_func) return _apply_fn @@ -227,39 +227,12 @@ def bucket_by_sequence_length(element_length_func, return _apply_fn -class _VariantDataset(dataset_ops.Dataset): - """A Dataset wrapper for a tf.variant-typed function argument.""" - - def __init__(self, dataset_variant, output_types, output_shapes, - output_classes): - super(_VariantDataset, self).__init__() - self._dataset_variant = dataset_variant - self._output_types = output_types - self._output_shapes = output_shapes - self._output_classes = output_classes - - def _as_variant_tensor(self): - return self._dataset_variant - - @property - def output_classes(self): - return self._output_classes - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_types(self): - return self._output_types - - -class GroupByReducerDataset(dataset_ops.Dataset): +class _GroupByReducerDataset(dataset_ops.Dataset): """A `Dataset` that groups its input and performs a reduction.""" def __init__(self, input_dataset, key_func, reducer): """See `group_by_reducer()` for details.""" - super(GroupByReducerDataset, self).__init__() + super(_GroupByReducerDataset, self).__init__() self._input_dataset = input_dataset @@ -388,12 +361,12 @@ class GroupByReducerDataset(dataset_ops.Dataset): **dataset_ops.flat_structure(self)) -class GroupByWindowDataset(dataset_ops.Dataset): +class _GroupByWindowDataset(dataset_ops.Dataset): """A `Dataset` that groups its input and performs a windowed reduction.""" def __init__(self, input_dataset, key_func, reduce_func, window_size_func): """See `group_by_window()` for details.""" - super(GroupByWindowDataset, self).__init__() + super(_GroupByWindowDataset, self).__init__() self._input_dataset = input_dataset @@ -431,24 +404,19 @@ class GroupByWindowDataset(dataset_ops.Dataset): def _make_reduce_func(self, reduce_func, input_dataset): """Make wrapping Defun for reduce_func.""" - def reduce_func_wrapper(key, window_dataset_variant): - """Wrapper that converts between tf.variant and Dataset objects.""" - window_dataset = _VariantDataset( - window_dataset_variant, input_dataset.output_types, - input_dataset.output_shapes, input_dataset.output_classes) - output_dataset = reduce_func(key, window_dataset) - if not isinstance(output_dataset, dataset_ops.Dataset): - raise TypeError("`reduce_func` must return a `Dataset` object.") - self._output_classes = output_dataset.output_classes - self._output_types = output_dataset.output_types - self._output_shapes = output_dataset.output_shapes - return output_dataset._as_variant_tensor() # pylint: disable=protected-access - + nested_dataset = dataset_ops._NestedDatasetComponent(input_dataset) # pylint: disable=protected-access wrapped_func = dataset_ops.StructuredFunctionWrapper( - reduce_func_wrapper, "tf.contrib.data.reduce_by_window()", - input_classes=(ops.Tensor, ops.Tensor), - input_shapes=(tensor_shape.scalar(), tensor_shape.scalar()), - input_types=(dtypes.int64, dtypes.variant)) + reduce_func, "tf.contrib.data.reduce_by_window()", + input_classes=(ops.Tensor, nested_dataset), + input_shapes=(tensor_shape.scalar(), nested_dataset), + input_types=(dtypes.int64, nested_dataset), + experimental_nested_dataset_support=True) + if not isinstance( + wrapped_func.output_classes, dataset_ops._NestedDatasetComponent): # pylint: disable=protected-access + raise TypeError("`reduce_func` must return a `Dataset` object.") + self._output_classes = wrapped_func.output_classes.output_classes + self._output_types = wrapped_func.output_types.output_types + self._output_shapes = wrapped_func.output_shapes.output_shapes self._reduce_func = wrapped_func.function @property diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 70153ac575758f16beff373941dfefb32bd342cf..bcc959594a6b311a3c60bb4696ac97be5c448756 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -153,7 +153,7 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): return _apply_fn -class DirectedInterleaveDataset(dataset_ops.Dataset): +class _DirectedInterleaveDataset(dataset_ops.Dataset): """A substitute for `Dataset.interleave()` on a fixed list of datasets.""" def __init__(self, selector_input, data_inputs): @@ -236,7 +236,7 @@ def sample_from_datasets(datasets, weights=None, seed=None): selector_input = dataset_ops.Dataset.zip( (logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset) - return DirectedInterleaveDataset(selector_input, datasets) + return _DirectedInterleaveDataset(selector_input, datasets) def choose_from_datasets(datasets, choice_dataset): @@ -280,4 +280,4 @@ def choose_from_datasets(datasets, choice_dataset): and choice_dataset.output_classes == ops.Tensor): raise TypeError("`choice_dataset` must be a dataset of scalar " "`tf.int64` tensors.") - return DirectedInterleaveDataset(choice_dataset, datasets) + return _DirectedInterleaveDataset(choice_dataset, datasets) diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py index 2ca3805d6609a82aa733da36d84c7fb58921d764..cf896572262929add5ac34d4fc8e4192c1049da3 100644 --- a/tensorflow/contrib/data/python/ops/optimization.py +++ b/tensorflow/contrib/data/python/ops/optimization.py @@ -39,17 +39,17 @@ def optimize(optimizations=None): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - return OptimizeDataset(dataset, optimizations) + return _OptimizeDataset(dataset, optimizations) return _apply_fn -class OptimizeDataset(dataset_ops.Dataset): +class _OptimizeDataset(dataset_ops.Dataset): """A `Dataset` that acts as an identity, and applies optimizations.""" def __init__(self, input_dataset, optimizations): """See `optimize()` for details.""" - super(OptimizeDataset, self).__init__() + super(_OptimizeDataset, self).__init__() self._input_dataset = input_dataset if optimizations is None: optimizations = [] diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py index e4c9f8b58a2a4390004b0ad318163526b443d44f..21fc17102e16a1f98f2c2e8aa0aeec89989edf67 100644 --- a/tensorflow/contrib/data/python/ops/prefetching_ops.py +++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py @@ -32,15 +32,32 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops as core_gen_dataset_ops -# TODO(rohanj): Add a python class that constructs resource in the __init__ -# method and provides a get_next() that calls the prefetch op. def function_buffering_resource(string_arg, target_device, f, buffer_size, + output_types, container="", shared_name=None, name=None): + """Creates a FunctionBufferingResource. + + A FunctionBufferingResource fills up a buffer by calling a function `f` on + `target_device`. `f` should take in only a single string argument as input. + + Args: + string_arg: The single string argument to the function. + target_device: The device to run `f` on. + f: The function to be executed. + buffer_size: Size of the buffer to be populated. + output_types: The output types generated by the function. + container: (Optional) string. Defaults to "". + shared_name: (Optional) string. + name: (Optional) string to name the op. + + Returns: + Handle to a FunctionBufferingResource. + """ if shared_name is None: shared_name = "" return gen_dataset_ops.function_buffering_resource( @@ -50,7 +67,8 @@ def function_buffering_resource(string_arg, f=f, buffer_size=buffer_size, container=container, - name=name) + name=name, + output_types=output_types) def function_buffering_resource_get_next(function_buffer_resource, @@ -123,7 +141,10 @@ class _PrefetchToDeviceIterator(object): target_device=iterator_device, string_arg=input_iterator_handle, buffer_size=buffer_size, - shared_name=shared_name) + shared_name=shared_name, + output_types=nest.flatten( + sparse.as_dense_types(self._input_dataset.output_types, + self._input_dataset.output_classes))) if not self._one_shot: reset_op = function_buffering_resource_reset(self._buffering_resource) @@ -212,6 +233,7 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator): with ops.device(device): self._buffering_resource = function_buffering_resource( f=_prefetch_fn, + output_types=self._flat_output_types, target_device=gen_dataset_ops.iterator_get_device(self._resource), string_arg=input_iterator_handle, buffer_size=buffer_size, diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py index f935beb1a9e85d4901857e7781a5ed8473838fa5..3f3c5ca17cf6ae22a719ed1d593d98eec37413fb 100644 --- a/tensorflow/contrib/data/python/ops/sliding.py +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -86,7 +86,7 @@ def sliding_window_batch(window_size, stride=1): elements in the sliding window. stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the steps moving the sliding window forward for one iteration. The default - is `1`. It must be in `[1, window_size)`. + is `1`. It must be positive. Returns: A `Dataset` transformation function, which can be passed to diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py index 3c82a03df1745d855b2d3f918f7bbde113600556..97931f75bd37d9e45864fe477c6e1620b5e4f193 100644 --- a/tensorflow/contrib/data/python/ops/stats_ops.py +++ b/tensorflow/contrib/data/python/ops/stats_ops.py @@ -23,6 +23,8 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops +# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. class StatsAggregator(object): """A stateful resource that aggregates statistics from one or more iterators. @@ -110,7 +112,8 @@ class _SetStatsAggregatorDataset(dataset_ops.Dataset): return self._input_dataset.output_classes -# TODO(shivaniagrawal): Expose these methods in `tf.contrib.data`. +# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. def set_stats_aggregator(stats_aggregator): """Set the given stats_aggregator for aggregating the input dataset stats. @@ -128,6 +131,8 @@ def set_stats_aggregator(stats_aggregator): return _apply_fn +# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. def bytes_produced_stats(tag): """Records the number of bytes produced by each element of the input dataset. @@ -150,6 +155,8 @@ def bytes_produced_stats(tag): return _apply_fn +# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. def latency_stats(tag): """Records the latency of producing each element of the input dataset. @@ -171,6 +178,8 @@ def latency_stats(tag): return _apply_fn +# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. def feature_stats(tag): """Records the features stats from `Example` records of the input dataset. diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py index bb49604d4de90d726418684124608438aa33e6cf..9af1e784ffb4f6d71da25f09d60343b649c5079b 100644 --- a/tensorflow/contrib/data/python/ops/threadpool.py +++ b/tensorflow/contrib/data/python/ops/threadpool.py @@ -37,22 +37,28 @@ def _generate_shared_name(prefix): return "{}{}".format(prefix, uid) +# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. class PrivateThreadPool(object): """A stateful resource that represents a private thread pool.""" - def __init__(self, num_threads, display_name=None): + def __init__(self, num_threads, display_name=None, + max_intra_op_parallelism=1): """Creates a `PrivateThreadPool` with the given number of threads.""" if context.executing_eagerly(): shared_name = _generate_shared_name("privatethreadpool") self._resource = gen_dataset_ops.thread_pool_handle( num_threads=num_threads, + max_intra_op_parallelism=max_intra_op_parallelism, display_name=display_name, shared_name=shared_name) self._resource_deleter = resource_variable_ops.EagerResourceDeleter( handle=self._resource, handle_device=context.context().device_name) else: self._resource = gen_dataset_ops.thread_pool_handle( - num_threads=num_threads, display_name=display_name) + num_threads=num_threads, + max_intra_op_parallelism=max_intra_op_parallelism, + display_name=display_name) class _ThreadPoolDataset(dataset_ops.Dataset): @@ -82,6 +88,8 @@ class _ThreadPoolDataset(dataset_ops.Dataset): return self._input_dataset.output_classes +# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. def override_threadpool(dataset, thread_pool): """Returns a new dataset that uses the given thread pool for its operations. diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py index 4ce6ddede8350735636fd152fdc9df0319265990..e0ce0a4ef15f6b9181bce92fb4d73bf1fab2e66c 100644 --- a/tensorflow/contrib/data/python/ops/unique.py +++ b/tensorflow/contrib/data/python/ops/unique.py @@ -42,17 +42,17 @@ def unique(): """ def _apply_fn(dataset): - return UniqueDataset(dataset) + return _UniqueDataset(dataset) return _apply_fn -class UniqueDataset(dataset_ops.Dataset): +class _UniqueDataset(dataset_ops.Dataset): """A `Dataset` contains the unique elements from its input.""" def __init__(self, input_dataset): """See `unique()` for details.""" - super(UniqueDataset, self).__init__() + super(_UniqueDataset, self).__init__() self._input_dataset = input_dataset if input_dataset.output_types not in (dtypes.int32, dtypes.int64, dtypes.string): diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 9dfb8552f1b0f058b44f8ed09c2ed681367293d5..eba0dd0ea330e29db0ea8e68ee14767fcb8ddad0 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -587,7 +587,6 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "noguitar", "notsan", ], ) diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index ba03b14deb9a3897dae29382ce601c0319f84735..9a8ea4aa48b8cf4c5906f18d8bddacc224e0b644 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -321,10 +321,6 @@ default_strategy = NamedDistribution( one_device_strategy = NamedDistribution( "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"), required_gpus=None) -tpu_strategy_single_iteration = NamedDistribution( - "TPUSingleIteration", - lambda: tpu_lib.TPUStrategy(iterations_per_step=1), - required_tpu=True) tpu_strategy = NamedDistribution("TPU", tpu_lib.TPUStrategy, required_tpu=True) # Note that we disable prefetching for testing since prefetching makes # the input non-deterministic. diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py index f8ae8b9712c392fa948c8598dd123cdea01d9866..0261ce43fa854d3b2ee38df19b8a8938cac3c8f3 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py @@ -32,7 +32,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import device_util -def _validate_destinations(destinations): +def validate_destinations(destinations): if not isinstance(destinations, (value_lib.DistributedValues, six.string_types, list)): raise ValueError("destinations must be one of a `DistributedValues` object," @@ -55,7 +55,7 @@ def _validate_value_destination_pairs(value_destination_pairs): # TODO(yuefengz): consider calling this function in the caller of CrossTowerOps. -def _get_devices_from(destinations): +def get_devices_from(destinations): if isinstance(destinations, value_lib.DistributedValues): return list(destinations.devices) elif isinstance(destinations, six.string_types): @@ -65,7 +65,7 @@ def _get_devices_from(destinations): def _devices_match(left, right): - return set(_get_devices_from(left)) == set(_get_devices_from(right)) + return set(get_devices_from(left)) == set(get_devices_from(right)) def _all_devices_match(value_destination_pairs): @@ -80,7 +80,7 @@ def _all_devices_match(value_destination_pairs): def _simple_broadcast(value, destinations): index = {} - devices = _get_devices_from(destinations) + devices = get_devices_from(destinations) for d in devices: index[d] = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( value, d) @@ -146,7 +146,7 @@ class CrossTowerOps(object): if not isinstance(per_device_value, value_lib.PerDevice): raise ValueError("`per_device_value` must be a `PerDevice` object.") if destinations is not None: - _validate_destinations(destinations) + validate_destinations(destinations) return self._reduce(method_string, per_device_value, destinations) def batch_reduce(self, method_string, value_destination_pairs): @@ -173,7 +173,7 @@ class CrossTowerOps(object): "tuples of PerDevice objects and destinations") for _, d in value_destination_pairs: if d is not None: - _validate_destinations(d) + validate_destinations(d) return self._batch_reduce(method_string, value_destination_pairs) @@ -187,7 +187,7 @@ class CrossTowerOps(object): Returns: a Mirrored object. """ - _validate_destinations(destinations) + validate_destinations(destinations) return self._broadcast(tensor, destinations) def _reduce(self, method_string, per_device_value, destinations): @@ -221,7 +221,7 @@ class ReductionToOneDeviceCrossTowerOps(CrossTowerOps): super(ReductionToOneDeviceCrossTowerOps, self).__init__() def _reduce(self, method_string, per_device_value, destinations): - devices = _get_devices_from(destinations or per_device_value) + devices = get_devices_from(destinations or per_device_value) reduce_to_device = self.reduce_to_device or devices[0] reduced = _simple_reduce(per_device_value, reduce_to_device, self.accumulation_fn, method_string) @@ -501,7 +501,7 @@ class AllReduceCrossTowerOps(CrossTowerOps): logging.WARN, "Efficient allreduce is not supported for IndexedSlices.", 10) - devices = _get_devices_from(destinations or per_device_value) + devices = get_devices_from(destinations or per_device_value) reduce_to_device = devices[0] reduced = _simple_reduce(per_device_value, reduce_to_device, math_ops.add_n, method_string) @@ -536,7 +536,7 @@ class AllReduceCrossTowerOps(CrossTowerOps): destinations = per_device_values[0].devices grouped = _group_value_by_device(per_device_values) - device_grad_packs, self._tensor_packer = _pack_tensors( + device_grad_packs, tensor_packer = _pack_tensors( grouped, self._num_packs, self._agg_small_grads_max_bytes, self._agg_small_grads_max_group) @@ -554,7 +554,7 @@ class AllReduceCrossTowerOps(CrossTowerOps): cross_tower_utils.aggregate_gradients_using_hierarchical_copy( destinations, device_grad_packs)) - reduced = _unpack_tensors(reduced, self._tensor_packer) + reduced = _unpack_tensors(reduced, tensor_packer) return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices, method_string) @@ -665,13 +665,13 @@ class MultiWorkerAllReduce(AllReduceCrossTowerOps): (this_grads, remaining_grads) = cross_tower_utils.split_grads_by_size( spec_tuple.limit, remaining_grads) if this_grads: - device_grad_packs, self._tensor_packer = _pack_tensors( + device_grad_packs, tensor_packer = _pack_tensors( this_grads, self._num_packs, self._agg_small_grads_max_bytes, self._agg_small_grads_max_group) range_agg_grads = cross_tower_utils.sum_gradients_all_reduce( self._worker_devices, device_grad_packs, len(self._worker_devices), spec_tuple.alg, spec_tuple.shards, range(self._num_gpus_per_worker)) - range_agg_grads = _unpack_tensors(range_agg_grads, self._tensor_packer) + range_agg_grads = _unpack_tensors(range_agg_grads, tensor_packer) if not aggregated_grads: aggregated_grads = range_agg_grads diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py index fed5505d92ef2544215069736c166a67d6141708..c540ea0d232e31af51ef4c2a1530250669e49495 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py @@ -36,7 +36,7 @@ from tensorflow.python.training import device_util def _make_per_device(values, devices): - devices = cross_tower_ops_lib._get_devices_from(devices) + devices = cross_tower_ops_lib.get_devices_from(devices) assert len(values) == len(devices) index = {} for d, v in zip(devices, values): @@ -53,7 +53,7 @@ def _fake_mirrored(value, devices): All components of the returned Mirrored have the same objects, which is not true in reality. """ - devices = cross_tower_ops_lib._get_devices_from(devices) + devices = cross_tower_ops_lib.get_devices_from(devices) return value_lib.Mirrored( {d: v for d, v in zip(devices, [value] * len(devices))}) @@ -93,7 +93,7 @@ class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase): self._assert_values_equal(l, r) else: self.assertEqual(type(left), type(right)) - self.assertEqual(left.devices, right.devices) + self.assertEqual(set(left.devices), set(right.devices)) if isinstance(list(left._index.values())[0], ops.IndexedSlices): for (d, v) in left._index.items(): self._assert_indexed_slices_equal(v, right._index[d]) diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py index 4ef8db681503dcef8c72f641455dbb999cef05cf..d25964fa41adc7b1c9164a4ffe49c4c5532f76ac 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py +++ b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py @@ -38,7 +38,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): self.evaluate(ops.convert_to_tensor(left)), self.evaluate(ops.convert_to_tensor(right))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAggregateTensors(self): t0 = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) t1 = constant_op.constant([[0., 0.], [5, 6], [7., 8.]]) @@ -46,7 +46,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1]) self._assert_values_equal(total, result) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAggregateIndexedSlices(self): t0 = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) @@ -57,7 +57,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): self.assertIsInstance(result, ops.IndexedSlices) self._assert_values_equal(total, result) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDivideTensor(self): t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) n = 2 @@ -65,7 +65,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n) self._assert_values_equal(expected, result) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDivideIndexedSlices(self): t = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) @@ -75,13 +75,13 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): self.assertIsInstance(result, ops.IndexedSlices) self._assert_values_equal(expected, result) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testIsIndexedSlices(self): t = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) self.assertTrue(cross_tower_utils.contains_indexed_slices(t)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testContainsIndexedSlices_List(self): t0 = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) @@ -89,7 +89,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) self.assertTrue(cross_tower_utils.contains_indexed_slices([t0, t1])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testContainsIndexedSlices_Tuple(self): t0 = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) @@ -97,7 +97,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) self.assertTrue(cross_tower_utils.contains_indexed_slices((t0, t1))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testContainsIndexedSlices_PerDevice(self): t0 = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) @@ -106,7 +106,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): per_device = value_lib.PerDevice({"/gpu:0": t0, "/cpu:0": t1}) self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testContainsIndexedSlices_PerDeviceMapOutput(self): t0 = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index 5c056a7c73def2f1fb4bbe0df4d3f82fdabda3df..aeeb9553e6044a0a928936597400e582e0329b95 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -56,6 +56,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): is_tpu=[True])) def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss, is_tpu): + # TODO(priyag): Remove this once the step TPU Strategy is stable. + if is_tpu: + self.skipTest("TPU tests are WIP.") + with distribution.scope(): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) @@ -84,8 +88,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): for _ in range(10): run_step() - weights.append(self.evaluate(distribution.fetch(layer.kernel))) - biases.append(self.evaluate(distribution.fetch(layer.bias))) + weights.append(self.evaluate(layer.kernel)) + biases.append(self.evaluate(layer.bias)) if is_tpu: with self.test_session() as sess: @@ -111,6 +115,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): is_tpu=[True])) def testOptimizerInsideModelFn(self, distribution, optimizer_fn, is_tpu): + # TODO(priyag): Remove this once the step TPU Strategy is stable. + if is_tpu: + self.skipTest("TPU tests are WIP.") + created_variables = [] trainable_variables = [] @@ -186,7 +194,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # towers will re-execute UPDATE_OPS of previous towers. update_ops_in_cross_tower_mode=[True])) + combinations.combine( - distribution=[combinations.tpu_strategy_single_iteration], + distribution=[combinations.tpu_strategy], optimizer_fn=[ combinations.gradient_descent_optimizer_v1_fn, combinations.gradient_descent_optimizer_v2_fn @@ -198,6 +206,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): renorm, is_tpu, update_ops_in_cross_tower_mode): """Verifies that moving mean updates are reduced across towers.""" + # TODO(priyag): Remove this once the step TPU Strategy is stable. + if is_tpu: + self.skipTest("TPU tests are WIP.") + with distribution.scope(): num_towers = len(distribution.worker_devices) model_fn, dataset_fn, batchnorm = batchnorm_example( @@ -242,7 +254,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): for _ in range(10): run_step() - moving_means = self.evaluate(distribution.fetch(batchnorm.moving_mean)) + moving_means = self.evaluate(batchnorm.moving_mean) # We make sure that the moving_mean is updated as if the sample mean is # calculated over all towers. @@ -279,12 +291,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): mode=["graph"], use_callable_loss=[True, False]) + combinations.combine(mode=["eager"], use_callable_loss=[True])) + combinations.combine( - distribution=[combinations.tpu_strategy_single_iteration], + distribution=[combinations.tpu_strategy], is_tpu=[True], mode=["graph"], use_callable_loss=[True, False]))) def testMeanVsSum(self, distribution, optimizer_fn, loss_reduction, use_callable_loss, is_tpu): + # TODO(priyag): Remove this once the step TPU Strategy is stable. + if is_tpu: + self.skipTest("TPU tests are WIP.") + with distribution.scope(): all_vars = [] @@ -329,7 +345,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): v = all_vars[0] self.assertTrue(all([v is vi for vi in all_vars[1:]])) - weight = numpy.squeeze(self.evaluate(distribution.fetch(v))) + weight = numpy.squeeze(self.evaluate(v)) # Our model is: # predict = x * w # loss = (predict - y)^2 diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 900aa10e93e8881aa236bac8a2873d5c5531c6f6..d269bed1e573fdb4b4ef8febd07ff882e3b82594 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -31,7 +31,6 @@ from tensorflow.python.eager import tape from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.training import coordinator from tensorflow.python.training import device_util @@ -109,6 +108,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): if tower_local is not None: kwargs["trainable"] = False + # Ignore user-specified caching device, not needed for mirrored variables. + kwargs.pop("caching_device", None) + # TODO(josh11b,apassos): It would be better if variable initialization # was never recorded on the tape instead of having to do this manually # here. @@ -283,8 +285,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def map(self, map_over, fn, *args, **kwargs): # TODO(josh11b): In eager mode, use one thread per device. index = {} - i = 0 - for m in map_over: + for i, m in enumerate(map_over): d = self._devices[i % len(self._devices)] with ops.device(d): l = index.get(d, []) @@ -308,9 +309,29 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): return self._cross_tower_ops def _reduce(self, method_string, value, destinations): - if len(self._devices) == 1 and not isinstance(value, values.PerDevice): - value = values.PerDevice({self._devices[0]: value}) - assert isinstance(value, values.PerDevice) + assert not isinstance(value, values.Mirrored) + if not isinstance(value, values.PerDevice): + if value == 0: + return 0 + if method_string == "mean": + return self._broadcast(value, destinations) + + cross_tower_ops_lib.validate_destinations(destinations) + if len(self._devices) == 1: + if destinations: + # TODO(anjalisridhar): Moves these methods to a device utility file? + devices = cross_tower_ops_lib.get_devices_from(destinations) + if len(devices) == 1: + with ops.device(devices[0]): + return array_ops.identity(value) + else: + value_updates = {} + for d in devices: + with ops.device(d): + value_updates[d] = array_ops.identity(value) + return values.Mirrored(value_updates) + raise ValueError("A non PerDevice value cannot be reduced with the given " + "method_string.") return self._get_cross_tower_ops().reduce( method_string, value, destinations=destinations) @@ -320,14 +341,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): value_destination_pairs) def _update(self, var, fn, *args, **kwargs): - # TODO(josh11b): Also support TowerLocalVariables here? If so, args and - # kwargs don't need to be mirrored. - assert isinstance(var, values.MirroredVariable) # TODO(josh11b): In eager mode, use one thread per device. + assert isinstance(var, values.DistributedVariable) updates = {} for d, v in var._index.items(): # pylint: disable=protected-access name = "update_%d" % self._device_index.get(d) with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): + # If args and kwargs are not mirrored, the value is returned as is. updates[d] = fn(v, *values.select_device_mirrored(d, args), **values.select_device_mirrored(d, kwargs)) @@ -347,37 +367,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def read_var(self, tower_local_var): """Read the aggregate value of a tower-local variable.""" if isinstance(tower_local_var, values.TowerLocalVariable): - return math_ops.add_n(self.unwrap(tower_local_var)) + return tower_local_var._get_cross_tower() # pylint: disable=protected-access assert isinstance(tower_local_var, values.Mirrored) return array_ops.identity(tower_local_var.get()) - def _fetch(self, val, destination, fn): - """Return a copy of `val` or `fn(val)` on `destination`.""" - if isinstance(val, values.TowerLocalVariable): - val = self.reduce(val.reduce_method, val, destinations=destination) - with ops.device(destination): - return fn(self.unwrap(val)[0]) - - assert isinstance(val, values.Mirrored), ( - "val = %s (type %s)" % (val, val.__class__.__name__)) - if val.on_device(destination): - with ops.device(destination): - # Use an identity here to make sure we are returning a tensor - # instead of e.g. a variable object. - return array_ops.identity(fn(val.get(destination))) - device = None - for d in self._devices: - if val.on_device(d): - device = d - break - assert device is not None, ( - "Could not find destination %s in list of devices %s." % - (destination, val.devices)) - with ops.device(device): - v = fn(val.get(device)) - with ops.device(destination): - return array_ops.identity(v) - def _unwrap(self, val): if isinstance(val, values.DistributedValues): # Return in a deterministic order. diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index bccd278847e3c87080af3cb15665e7a0d802d8fb..8d474124b7e0a80d49ed646254269988f49d69e4 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -32,12 +32,14 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.layers import core +from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import distribute as distribute_lib + GPU_TEST = "test_gpu" in sys.argv[0] @@ -83,13 +85,13 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): self.skipTest("Not GPU test") self.assertEqual(2, self._get_distribution_strategy().num_towers) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCallAndMergeExceptions(self): if not GPU_TEST: self.skipTest("Not GPU test") self._test_call_and_merge_exceptions(self._get_distribution_strategy()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRunRegroupError(self): def run_fn(device_id): @@ -101,7 +103,7 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): with dist.scope(), self.assertRaises(AssertionError): dist.call_for_each_tower(run_fn, dist.worker_device_index) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testReduceToCpu(self): if not GPU_TEST: self.skipTest("Not GPU test") @@ -118,6 +120,24 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): expected = sum(range(len(dist.worker_devices))) self.assertEqual(expected, self.evaluate(unwrapped[0])) + @test_util.run_in_graph_and_eager_modes() + def testReduceToMultipleDestinations(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + + devices = ["/device:GPU:0"] + if GPU_TEST: + self.assertGreater(context.num_gpus(), 0) + print(self.id().split(".")[-1], "devices:", ", ".join(devices)) + + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + reduced = dist.reduce("sum", 1.0, destinations=["/device:CPU:0", + "/device:GPU:0"]) + unwrapped = dist.unwrap(reduced) + self.assertEqual(2, len(unwrapped)) + self.assertEqual(1.0, self.evaluate(unwrapped[0])) + class MirroredStrategyVariableCreationTest(test.TestCase): @@ -337,6 +357,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): all_v_sum = {} all_v_mean = {} + components_sum = {} + components_mean = {} def model_fn(device_id): tower_context = distribute_lib.get_tower_context() @@ -350,21 +372,33 @@ class MirroredStrategyVariableCreationTest(test.TestCase): v_mean.assign(6.0 * device_id)] all_v_sum[device_id] = v_sum all_v_mean[device_id] = v_mean - return updates, v_sum, v_mean + c_sum = v_sum.get() + c_mean = v_mean.get() + components_sum[device_id] = c_sum + components_mean[device_id] = c_mean + self.assertIsNot(v_sum, c_sum) + self.assertIsNot(v_mean, c_mean) + return updates, v_sum, v_mean, c_sum, c_mean dist = mirrored_strategy.MirroredStrategy( ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): # Create "sum" and "mean" versions of TowerLocalVariables. - ret_ops, ret_v_sum, ret_v_mean = dist.call_for_each_tower( - model_fn, dist.worker_device_index, run_concurrently=False) + ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = ( + dist.call_for_each_tower( + model_fn, dist.worker_device_index, run_concurrently=False)) # Should see the same wrapping instance in all towers. self.assertIs(all_v_sum[0], ret_v_sum) self.assertIs(all_v_mean[0], ret_v_mean) - for i in range(1, dist.num_towers): - self.assertIs(all_v_sum[0], all_v_sum[1]) - self.assertIs(all_v_mean[0], all_v_mean[1]) + self.assertIs(all_v_sum[0], all_v_sum[1]) + self.assertIs(all_v_mean[0], all_v_mean[1]) + + # Regroup should recover the same wrapper. + self.assertIs(ret_v_sum, regrouped_sum) + self.assertIs(ret_v_mean, regrouped_mean) + self.assertIsNot(components_sum[0], components_sum[1]) + self.assertIsNot(components_mean[0], components_mean[1]) # Apply updates self.evaluate(variables.global_variables_initializer()) @@ -385,14 +419,13 @@ class MirroredStrategyVariableCreationTest(test.TestCase): # Without get(device), should return the value you get by # applying the reduction across all towers (whether you use - # fetch(), get(), or nothing). - self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum))) - self.assertEqual(expected_mean, self.evaluate(dist.fetch(ret_v_mean))) + # read_var(), get(), or nothing). + self.assertEqual(expected_sum, self.evaluate(dist.read_var(ret_v_sum))) + self.assertEqual(expected_mean, self.evaluate(dist.read_var(ret_v_mean))) self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get())) self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get())) - if not context.executing_eagerly(): - self.assertEqual(expected_sum, self.evaluate(ret_v_sum)) - self.assertEqual(expected_mean, self.evaluate(ret_v_mean)) + self.assertEqual(expected_sum, self.evaluate(ret_v_sum)) + self.assertEqual(expected_mean, self.evaluate(ret_v_mean)) # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not # testing this in eager mode. @@ -530,6 +563,239 @@ class MirroredStrategyVariableCreationTest(test.TestCase): _, v1 = dist.unwrap(v) self.assertStartsWith(v1.name, "tower_1/") + @test_util.run_in_graph_and_eager_modes(config=config) + def testTowerLocalVariableUpdate(self): + with context.graph_mode(): + + def model_fn(): + tower_context = distribute_lib.get_tower_context() + with tower_context.tower_local_var_scope("sum"): + v_sum = variable_scope.variable(1.0) + self.assertTrue(isinstance(v_sum, values.TowerLocalVariable)) + return v_sum + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]) + + def update(var, value): + return var.assign(value) + + with dist.scope(): + ret_v_sum = dist.call_for_each_tower(model_fn, run_concurrently=False) + update_ops = dist.unwrap(dist.update(ret_v_sum, update, 5.0)) + + # Initialize variables. + self.evaluate(variables.global_variables_initializer()) + # Assert that the aggregated value of the tower local vars is the sum of + # the individual values before running the update ops. + self.assertEquals(1.0, self.evaluate( + ret_v_sum.get(dist._devices[0]).read_value())) + self.assertEquals(2.0, self.evaluate(ret_v_sum)) + + # Apply updates. + self.evaluate(update_ops) + # Assert that the aggregated value of the tower local vars is the sum of + # the individual values after running the update ops. + self.assertEquals(5.0, self.evaluate( + ret_v_sum.get(dist._devices[0]).read_value())) + self.assertEquals(10.0, self.evaluate(ret_v_sum)) + + +class MirroredVariableUpdateTest(test.TestCase): + # The following tests check assign, assign_add and assign_sub on Mirrored + # variables in tower and cross tower context. + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + def _skip_eager_if_gpus_less_than(self, num_gpus): + if context.num_gpus() < num_gpus and context.executing_eagerly(): + self.skipTest("Enough GPUs not available for this test in eager mode.") + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignMirroredVarTowerContextWithoutAggregationType(self): + # Test that we always have an aggregation type set on the mirrored variable + # if we assign to it in tower mode. + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + v = variable_scope.variable(1.0, name="foo") + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + + def model_fn(): + return mirrored_var.assign(5.0) + + with self.assertRaisesRegexp( + ValueError, "You must specify an aggregation method to update a " + "MirroredVariable in Tower Context."): + self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn))) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignMirroredVarTowerContextWithSum(self): + # Test that we don't reduce a non-per-device value with the "sum" + # aggregation type. + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + v = variable_scope.variable(1.0, name="foo") + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the + # aggregation method. + mirrored_var._aggregation_method = "sum" + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + + def model_fn(): + return mirrored_var.assign(5.0) + + with self.assertRaisesRegexp( + ValueError, "A non PerDevice value cannot be reduced with the given " + "method_string."): + self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn))) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignMirroredVarCrossTowerContext(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable(1.0, name="foo") + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(1.0, self.evaluate(mirrored_var)) + mirrored_var_result = self.evaluate(mirrored_var.assign(6.0)) + self.assertEquals(6.0, mirrored_var_result) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignMirroredVarTowerContext(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable(1.0, name="foo") + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the + # aggregation method. + mirrored_var._aggregation_method = "mean" + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(1.0, self.evaluate(mirrored_var)) + + def model_fn(): + value = math_ops.cast(distribute_lib.get_tower_context().tower_id, + mirrored_var.dtype) + return mirrored_var.assign(value) + + self.evaluate(dist.unwrap(dist.call_for_each_tower( + model_fn, run_concurrently=False))) + self.assertEquals(0.5, self.evaluate(mirrored_var)) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignAddMirroredVarCrossTowerContext(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable(1.0, name="foo") + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(1.0, self.evaluate(mirrored_var)) + mirrored_var_result = self.evaluate(mirrored_var.assign_add(6.0)) + self.assertEquals(7.0, mirrored_var_result) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignAddMirroredVarTowerContext(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable(1.0, name="foo") + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the + # aggregation method. + mirrored_var._aggregation_method = "mean" + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(1.0, self.evaluate(mirrored_var)) + + def model_fn(): + value = math_ops.cast(distribute_lib.get_tower_context().tower_id, + mirrored_var.dtype) + return mirrored_var.assign_add(value) + + self.evaluate(dist.unwrap(dist.call_for_each_tower( + model_fn, run_concurrently=False))) + self.assertEquals(1.5, self.evaluate(mirrored_var)) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignSubMirroredVarCrossTowerContext(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable(5.0, name="foo") + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(5.0, self.evaluate(mirrored_var)) + mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0)) + self.assertEquals(3.0, mirrored_var_result) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignSubMirroredVarTowerContext(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable(5.0, name="foo") + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the + # aggregation method. + mirrored_var._aggregation_method = "mean" + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(5.0, self.evaluate(mirrored_var)) + + def model_fn(): + value = math_ops.cast(distribute_lib.get_tower_context().tower_id, + mirrored_var.dtype) + return mirrored_var.assign_sub(value) + + self.evaluate(dist.unwrap(dist.call_for_each_tower( + model_fn, run_concurrently=False))) + self.assertEquals(4.5, self.evaluate(mirrored_var)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py index 61cbe6df813bb28bf8baa83d9e28ffafc4f0cbb8..a066adf1246ecd9ab8bd6a85be1f1e9be2c35b17 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py @@ -47,7 +47,7 @@ class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): def testTowerId(self): self._test_tower_id(self._get_distribution_strategy()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCallAndMergeExceptions(self): self._test_call_and_merge_exceptions(self._get_distribution_strategy()) diff --git a/tensorflow/contrib/distribute/python/monitor_test.py b/tensorflow/contrib/distribute/python/monitor_test.py index 4fdb9bf69b4f6ad76b79fd298f5303f24a1bd455..2892ce439494320a115b8eae0025a132841c4a8f 100644 --- a/tensorflow/contrib/distribute/python/monitor_test.py +++ b/tensorflow/contrib/distribute/python/monitor_test.py @@ -52,11 +52,11 @@ class MonitorTest(test.TestCase, parameterized.TestCase): self.assertEqual(1, len(layer.trainable_variables)) mirrored_weight_variable = layer.trainable_variables[0] - start_error = self.evaluate(distribution.fetch(mirrored_weight_variable)) + start_error = self.evaluate(mirrored_weight_variable) start_error = abs(numpy.array(start_error) - 1) monitor.run_steps(9) - end_error = self.evaluate(distribution.fetch(mirrored_weight_variable)) + end_error = self.evaluate(mirrored_weight_variable) end_error = abs(numpy.array(end_error) - 1) self.assertGreaterEqual(start_error, end_error) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 7f4bab9d93814eb70a2a1586fc291a16b2766b90..a580dac96c5e6c6c8790aa6af7309988bf7a6477 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -106,13 +106,6 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): """Read the aggregate value of a tower-local variable.""" return array_ops.identity(tower_local_var) - def _fetch(self, val, destination, fn): - """Return a copy of `val` or `fn(val)` on `destination`.""" - with ops.device(self._device): - v = fn(val) - with ops.device(destination): - return array_ops.identity(v) - def _unwrap(self, value): return [value] diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py index 7aad8a953cbedd30b48739416e74b3dc164dc4cd..4fdc0f72e6745b7ef25c591157955f214e0b2c79 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy_test.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -44,7 +44,7 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): def testTowerId(self): self._test_tower_id(self._get_distribution_strategy()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCallAndMergeExceptions(self): self._test_call_and_merge_exceptions(self._get_distribution_strategy()) diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py index abd3a65ac4e19ece6b69b9834f4218fde55b60c2..a2d736e42271ab1627240949b99088ed3f0746f6 100644 --- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -59,8 +59,8 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): for _ in range(10): run_step() - weights.append(self.evaluate(distribution.fetch(layer.kernel))) - biases.append(self.evaluate(distribution.fetch(layer.bias))) + weights.append(self.evaluate(layer.kernel)) + biases.append(self.evaluate(layer.bias)) error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py index 7b3670b45aba801cf8c18e04bfea03e23eb67184..24cdc627a35f4455cb92484566dc13fa1bbaf2cc 100644 --- a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py @@ -89,6 +89,9 @@ class _PrefetchToDeviceIterator(object): with ops.device(device): buffer_resource_handle = prefetching_ops.function_buffering_resource( f=_prefetch_fn, + output_types=data_nest.flatten( + sparse.as_dense_types(self._input_dataset.output_types, + self._input_dataset.output_classes)), target_device=target_device, string_arg=input_iterator_handle, buffer_size=buffer_size, diff --git a/tensorflow/contrib/distribute/python/shared_variable_creator_test.py b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py index a0b452fc2d445d1cf7dbf5e8fe0e29edef516207..2a9ab51fcfd29a8ae5b37b5c513415af29b277dc 100644 --- a/tensorflow/contrib/distribute/python/shared_variable_creator_test.py +++ b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py @@ -46,7 +46,7 @@ class CanonicalizeVariableNameTest(test.TestCase): class SharedVariableCreatorTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSharedVariable(self): shared_variable_store = {} diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py index 75c5ec9659d193e77d219ba79977615d58841d64..2ee94d8f70868c07ca217dd4d433585458efa8d8 100644 --- a/tensorflow/contrib/distribute/python/step_fn_test.py +++ b/tensorflow/contrib/distribute/python/step_fn_test.py @@ -50,8 +50,8 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase): for _ in range(10): run_step() - weights.append(self.evaluate(distribution.fetch(layer.kernel))) - biases.append(self.evaluate(distribution.fetch(layer.bias))) + weights.append(self.evaluate(layer.kernel)) + biases.append(self.evaluate(layer.bias)) error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index 2b4ad9f146bc1d6a987fbeecbb05122946137154..d2fe8b3b1efabf7b35c070a82d01595f3fa51bf9 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -106,13 +106,13 @@ class DistributionTestBase(test.TestCase): before_list = [] after_list = [] for g, v in g_v: - fetched = d.fetch(v) + fetched = d.read_var(v) before_list.append(fetched) # control_dependencies irrelevant but harmless in eager execution with ops.control_dependencies([fetched]): g = d.reduce("sum", g, destinations=v) with ops.control_dependencies(d.unwrap(d.update(v, update, g))): - after_list.append(d.fetch(v)) + after_list.append(d.read_var(v)) return before_list, after_list for i in range(10): @@ -159,12 +159,12 @@ class DistributionTestBase(test.TestCase): before_list = [] after_list = [] for g, v in g_v: - fetched = d.fetch(v) + fetched = d.read_var(v) before_list.append(fetched) with ops.control_dependencies([fetched]): g = d.reduce("sum", g, destinations=v) with ops.control_dependencies(d.unwrap(d.update(v, update, g))): - after_list.append(d.fetch(v)) + after_list.append(d.read_var(v)) return before_list, after_list before_out, after_out = step() @@ -184,7 +184,7 @@ class DistributionTestBase(test.TestCase): with d.scope(): map_in = [constant_op.constant(i) for i in range(10)] map_out = d.map(map_in, lambda x, y: x * y, 2) - observed = d.fetch(d.reduce("sum", map_out)) + observed = d.reduce("sum", map_out) expected = 90 # 2 * (0 + 1 + ... + 9) self.assertEqual(expected, observed.numpy()) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 75441786a615fc0d87b4c4b0b45b9384d678c1d3..1ae12ae98aaecbb0ce46a944d8e61e051627ff51 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -21,14 +21,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import itertools - from tensorflow.contrib import tpu from tensorflow.contrib.distribute.python import one_device_strategy from tensorflow.contrib.distribute.python import values from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.util import nest @@ -36,85 +35,107 @@ from tensorflow.python.util import nest class TPUStrategy(one_device_strategy.OneDeviceStrategy): """Experimental TPU distribution strategy implementation.""" - def __init__(self, - num_cores_per_host=2, - iterations_per_step=2): + def __init__(self, num_cores_per_host=2): # TODO(isaprykin): Generalize the defaults. They are currently tailored for # the unit test. super(TPUStrategy, self).__init__('/cpu:0') # TODO(isaprykin): Auto-detect number of cores and hosts. self._num_cores_per_host = num_cores_per_host - # TODO(isaprykin): This might have to be per-call. - self._iterations_per_step = iterations_per_step + # TODO(priyag): This should not be hardcoded here. + self._host = '/task:0/device:CPU:0' def distribute_dataset(self, dataset_fn): - return values.PerIterationDataset( - self._call_dataset_fn(dataset_fn), self._iterations_per_step, - self._num_cores_per_host) - - def _call_for_each_tower(self, fn, *args, **kwargs): - kwargs.pop('run_concurrently', None) - - inputs = {'args': args, 'kwargs': kwargs} - flat_inputs = nest.flatten(inputs) - - feed_mask = [isinstance(f, values.PerIteration) for f in flat_inputs] - - feeds = lambda: itertools.compress(flat_inputs, feed_mask) - shapes = [f.get_shape() for f in feeds()] + # TODO(priyag): Perhaps distribute across cores here. + return self._call_dataset_fn(dataset_fn) + + # TODO(priyag): Deal with OutOfRange errors. + # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have + # a mechanism to infer the outputs of `fn`. Pending b/110550782. + def _run_steps_on_dataset(self, fn, iterator, iterations, + initial_loop_values=None): + # Enqueue ops + shapes = nest.flatten(iterator.output_shapes) if any([not s.is_fully_defined() for s in shapes]): raise ValueError( 'TPU currently requires fully defined shapes. Either use ' 'set_shape() on the input tensors or use ' 'dataset.apply(map_and_batch(..., drop_remainder=True)).') - types = [f.get_dtype() for f in feeds()] - - def infeed_input(i): - """Get input, split it and then enqueue.""" - iteration_inputs = [f.get(i) for f in feeds()] - infeed_inputs = [[inputs_per_core[core_id] - for inputs_per_core in iteration_inputs] - for core_id in range(self._num_cores_per_host)] - - infeed_ops = [] - for core_id, infeed_input in enumerate(infeed_inputs): - infeed_ops.append( + types = nest.flatten(iterator.output_types) + + def enqueue_ops_fn(): + """Enqueue ops for one iteration.""" + control_deps = [] + sharded_inputs = [] + with ops.device(self._host): + for _ in range(self._num_cores_per_host): + # Use control dependencies to ensure a deterministic ordering. + with ops.control_dependencies(control_deps): + inputs = nest.flatten(iterator.get_next()) + control_deps.extend(inputs) + sharded_inputs.append(inputs) + + enqueue_ops = [] + for core_id, shard_input in enumerate(sharded_inputs): + enqueue_ops.append( tpu_ops.infeed_enqueue_tuple( - inputs=infeed_input, shapes=shapes, device_ordinal=core_id)) + inputs=shard_input, shapes=shapes, device_ordinal=core_id)) + return enqueue_ops - with ops.control_dependencies(infeed_ops): + def enqueue_ops_loop_body(i): + with ops.control_dependencies(enqueue_ops_fn()): return i + 1 - with ops.device('/task:0/device:CPU:0'): + with ops.device(self._host): enqueue_ops = control_flow_ops.while_loop( - lambda i: i < self._iterations_per_step, - infeed_input, [constant_op.constant(0)], + lambda i: i < iterations, + enqueue_ops_loop_body, + [constant_op.constant(0)], parallel_iterations=1) - def dequeueing_fn(*args, **kwargs): - """Dequeue input arguments and supply them to `fn`.""" - del args, kwargs + # Dequeue ops + def dequeue_fn(): dequeued = tpu.infeed_dequeue_tuple(dtypes=types, shapes=shapes) - dequeued = iter(dequeued) + return nest.pack_sequence_as(iterator.output_shapes, dequeued) - fn_inputs = [] - for inp, is_feed in zip(flat_inputs, feed_mask): - if is_feed: - fn_inputs.append(next(dequeued)) - else: - fn_inputs.append(inp) + # Wrap `fn` for repeat. + if initial_loop_values is None: + initial_loop_values = [] + ctx = values.MultiStepContext(initial_loop_values) + def run_fn(*args, **kwargs): + del args, kwargs + fn_result = fn(ctx, dequeue_fn()) + if ctx.last_step_outputs is None: + ctx.last_step_outputs = [] + with ops.control_dependencies([fn_result]): + return array_ops.identity(ctx.last_step_outputs) + + # Repeat + # TODO(sourabhbajaj): The input to while loop should be based on the output + # type of the step_fn + def iterate_on_tpu(): + return tpu.repeat(iterations, run_fn, [initial_loop_values]) - fn_inputs = nest.pack_sequence_as(inputs, fn_inputs) - return fn(*fn_inputs['args'], **fn_inputs['kwargs']) + # Re-write and distribute computation. + # TODO(sourabhbajaj): Convert the output to PerDevice variable and + # implement support for that in reduce. + last_step_tensor_outputs = tpu.batch_parallel( + iterate_on_tpu, [], num_shards=self._num_cores_per_host) - def iterate_on_tpu(): - return tpu.repeat(self._iterations_per_step, dequeueing_fn, []) + # Take index [0] of last_step_tensor_outputs as we wrapped + # initial_loop_values in a list in the `repeat` call. + return (control_flow_ops.group(last_step_tensor_outputs, enqueue_ops), + last_step_tensor_outputs[0], ctx) + def _call_for_each_tower(self, fn, *args, **kwargs): + kwargs.pop('run_concurrently', None) with one_device_strategy._OneDeviceTowerContext(self): # pylint: disable=protected-access - tpu_result = tpu.batch_parallel( - iterate_on_tpu, [], num_shards=self._num_cores_per_host) + return fn(*args, **kwargs) + + def get_initialization_ops(self): + return [tpu.initialize_system()] - return control_flow_ops.group(tpu_result, enqueue_ops) + def get_finalize_ops(self): + return [tpu.shutdown_system()] def _reduce(self, method_string, value, destinations): del destinations # TPU is graph mode only. Rely on implicit Send/Recv. diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index aca544b7e7e3c6f706377de9846881bea19b92d0..95390041f45a6dc9111454f2318cdff5aff017ed 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -23,10 +23,8 @@ from __future__ import print_function import collections import weakref - import six -from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.distribute.python import input_ops from tensorflow.contrib.distribute.python import prefetching_ops_v2 from tensorflow.python.eager import context @@ -35,6 +33,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import saver @@ -43,7 +42,7 @@ from tensorflow.python.util import nest # pylint: disable=line-too-long -# TODO(josh11b): Should device values be strings or DeviceSpec objects +# TODO(josh11b): Should device values be strings or DeviceSpec objects? # Not sure DeviceSpec objects are usable as a dict key. class DistributedValues(object): """Holds a map from device to values. Either PerDevice or Mirrored.""" @@ -163,9 +162,16 @@ class PerDevice(DistributedValues): pass -class Mirrored(DistributedValues): +# Note that unlike PerDevice, Mirrored values inherit from +# DistributedDelegate and so can be used directly in cross-tower mode. +class Mirrored(DistributedDelegate): """Holds a map from device to values which are kept in sync.""" - pass + + def _get_cross_tower(self): + device = device_util.canonicalize(device_util.current()) + if device in self._index: + return self._index[device] + return list(self._index.values())[0] def _assign_on_device(device, variable, tensor): @@ -186,6 +192,10 @@ class DistributedVariable(DistributedDelegate): # Child class must set self._primary_var before calling # super(...).__init__(index). self._common_name = self._primary_var.name.split(":")[0] + # Use a weakref to make it easy to map from the contained values + # to the container without introducing a reference cycle. + for v in six.itervalues(index): + v._distributed_container = weakref.ref(self) # pylint: disable=protected-access super(DistributedVariable, self).__init__(index) @property @@ -241,21 +251,6 @@ class DistributedVariable(DistributedDelegate): ops.register_dense_tensor_like_type(DistributedVariable) -class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable): - """Class for defining how to restore a MirroredVariable.""" - - def __init__(self, mirrored_variable, primary_variable, name): - self._mirrored_variable = mirrored_variable - super(_MirroredSaveable, self).__init__(primary_variable, "", name) - - def restore(self, restored_tensors, restored_shapes): - """Restore the same value into all variables.""" - tensor, = restored_tensors - return control_flow_ops.group([ - _assign_on_device(d, v, tensor) - for d, v in six.iteritems(self._mirrored_variable._index)]) # pylint: disable=protected-access - - def _get_update_device(): """Validate we are in update/update_non_slot() and return current device. @@ -276,34 +271,82 @@ def _get_update_device(): return device +class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable): + """Class for defining how to restore a MirroredVariable.""" + + def __init__(self, mirrored_variable, primary_variable, name): + self._mirrored_variable = mirrored_variable + super(_MirroredSaveable, self).__init__(primary_variable, "", name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into all variables.""" + tensor, = restored_tensors + return control_flow_ops.group([ + _assign_on_device(d, v, tensor) + for d, v in six.iteritems(self._mirrored_variable._index)]) # pylint: disable=protected-access + + class MirroredVariable(DistributedVariable, Mirrored, checkpointable.CheckpointableBase): """Holds a map from device to variables whose values are kept in sync.""" - def __init__(self, index, primary_var): + def __init__(self, index, primary_var, aggregation_method=None): # Use a weakref to make it easy to map from the contained values # to the container without introducing a reference cycle. for v in six.itervalues(index): v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access self._primary_var = primary_var + self._aggregation_method = aggregation_method super(MirroredVariable, self).__init__(index) - # We use _get_update_device() for the assign* methods to enforce - # that we are in an update() function. The arguments to update() are - # automatically unwrapped so the update() function would normally - # see regular variables, not MirroredVariables. However, the update - # function can still operate on wrapped MirroredVariables through - # object members, captured arguments, etc. This is more likely in an + # The arguments to update() are automatically unwrapped so the update() + # function would normally see regular variables, not MirroredVariables. + # However, the update function can still operate on wrapped MirroredVariables + # through object members, captured arguments, etc. This is more likely in an # update_non_slot() function (like OptimizerV2._finish), which can # update several non-slot variables in one call. + def _assign_func(self, *args, **kwargs): + f = kwargs.pop("f") + if distribute_lib.get_cross_tower_context(): + update_device = distribute_lib.get_update_device() + # We are calling update on the mirrored variable in cross tower context. + if update_device is not None: + # We are calling an assign function on the mirrored variable in cross + # tower context. + v = self.get(device=update_device) + return f(v, *args, **kwargs) + + return distribute_lib.get_distribution_strategy().update( + self, f, *args, **kwargs) + else: + # We are calling an assign function on the mirrored variable in tower + # context. + # We reduce the value we want to assign/add/sub. More details about how we + # handle the different use cases can be found in the _reduce method. + # We call the function on each of the mirrored variables with the reduced + # value. + if not self._aggregation_method: + raise ValueError("You must specify an aggregation method to update a " + "MirroredVariable in Tower Context.") + + def merge_fn(strategy, value): + return strategy.update(self, + f, + strategy.reduce( + method_string=self._aggregation_method, + value=value, + destinations=self)) + return distribute_lib.get_tower_context().merge_call(merge_fn, *args, + **kwargs) + def assign_sub(self, *args, **kwargs): - return self.get(device=_get_update_device()).assign_sub(*args, **kwargs) + return self._assign_func(f=state_ops.assign_sub, *args, **kwargs) def assign_add(self, *args, **kwargs): - return self.get(device=_get_update_device()).assign_add(*args, **kwargs) + return self._assign_func(f=state_ops.assign_add, *args, **kwargs) def assign(self, *args, **kwargs): - return self.get(device=_get_update_device()).assign(*args, **kwargs) + return self._assign_func(f=state_ops.assign, *args, **kwargs) def _get_cross_tower(self): device = device_util.canonicalize(device_util.current()) @@ -353,7 +396,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): # We use a callable so that we don't have to evaluate this expression # in the case where we are trying to restore instead of save. def tensor(): - return distribute_lib.get_distribution_strategy().fetch( + return distribute_lib.get_distribution_strategy().read_var( tower_local_variable) spec = saver.BaseSaverBuilder.SaveSpec( tensor=tensor, @@ -492,40 +535,40 @@ def regroup(per_device, wrap_class=PerDevice): same_id = False break # Consider three cases where same_id is true: - # * If v0 is a MirroredVariable (and same_id means it is the same - # across all devices), we want to return it. We check - # MirroredVariable specifically since it can look like it - # has a _mirrored_container member since its members do. - # * If v0 is a member of a mirrored variable, in which case - # hasattr(v0, "_mirrored_container") is true, we want to - # return the MirroredVariable that contains it using the - # _mirrored_container logic below. This case can trigger + # * If v0 is a DistributedVariable (a MirroredVariable or + # TowerLocalVariable, and same_id means it is the same across all + # devices), we want to return it. We check DistributedVariable + # specifically since it can look like it has a + # _distributed_container member since its members do. + # * If v0 is a member of a distributed variable, in which case + # hasattr(v0, "_distributed_container") is true, we want to + # return the DistributedVariable that contains it using the + # _distributed_container logic below. This case can trigger # same_id when there is only one device. # * In any other situation, same_id means we return v0. - if same_id and (isinstance(v0, MirroredVariable) or - not hasattr(v0, "_mirrored_container")): + if same_id and (isinstance(v0, DistributedVariable) or + not hasattr(v0, "_distributed_container")): return v0 # Detect the case where each device has a parallel component of the - # same MirroredVariable. In this case we want to return the - # containing MirroredVariable, after a bunch of sanity checking. - # In particular, each component should have the same container, - # and the devices of the variables should match the keys of the - # per-device dictionary. - # TODO(josh11b): Do we need similar logic for TowerLocalVariables? - if hasattr(v0, "_mirrored_container"): + # same MirroredVariable (or TowerLocalVariable). In this case we + # want to return the containing MirroredVariable, after a bunch of + # sanity checking. In particular, each component should have the + # same container, and the devices of the variables should match the + # keys of the per-device dictionary. + if hasattr(v0, "_distributed_container"): # pylint: disable=protected-access assert not isinstance(v0, MirroredVariable), ( "ids = %s, items = %s" % ([id(v[1]) for v in items], items)) assert _devices_match(v0.device, items[0][0]), ( "v0.device = %s, items = %s" % (v0.device, items)) - mirrored_container = v0._mirrored_container() - assert mirrored_container is not None + distributed_container = v0._distributed_container() + assert distributed_container is not None for d, v in items[1:]: assert _devices_match(v.device, d), ( "v.device = %s, d = %s, items = %s" % (v.device, d, items)) - assert mirrored_container is v._mirrored_container() - return mirrored_container + assert distributed_container is v._distributed_container() + return distributed_container # pylint: enable=protected-access return wrap_class(per_device) @@ -607,8 +650,7 @@ class PerDeviceDataset(object): # TODO(priyag): If dropping remainder is not appropriate, find another # approach to distributing the dataset when not possible to divide evenly. # Possibly not an issue when we start using PartitionedDataset. - self._dataset = dataset.apply( - batching.batch_and_drop_remainder(len(devices))) + self._dataset = dataset.batch(len(devices), drop_remainder=True) def make_one_shot_iterator(self): """Get a one time use iterator for the distributed PerDeviceDataset.""" @@ -819,3 +861,72 @@ class MapOutput(object): def get(self): return self._l + + +class MultiStepContext(object): + """A context object that can be used to capture things when running steps. + + This context object is useful when running multiple steps at a time using the + `run_steps_on_dataset` API. For e.g. it allows the user's step function to + specify which outputs to emit at what frequency. Currently it only supports + capturing output from the last step, but will soon be augmented to support + other use cases such as output each N steps. + """ + + def __init__(self, initial_loop_values=None): + """Initializes an output context. + + Args: + initial_loop_values: Initial values passed to the run steps + while loop. The only purpose is to verify the shapes and types + when the actual output is set. This will be removed once we + automatically infer the output shapes and types (and do not need to + check for user error in specifying them manually). + Returns: + A context object. + """ + self._last_step_outputs = None + self._non_tensor_outputs = None + self._initial_loop_values = initial_loop_values + + @property + def last_step_outputs(self): + """Return the last step's outputs.""" + return self._last_step_outputs + + @last_step_outputs.setter + def last_step_outputs(self, outputs): + """Set the last step's outputs.""" + self._verify_structure_shapes_types(outputs, self._initial_loop_values) + self._last_step_outputs = outputs + + @property + def non_tensor_outputs(self): + """Return the non tensor outputs.""" + return self._non_tensor_outputs + + @non_tensor_outputs.setter + def non_tensor_outputs(self, outputs): + """Set any non tensor outputs.""" + self._non_tensor_outputs = outputs + + def _verify_structure_shapes_types(self, left, right): + """Verify that the structure, shapes and types of left are same as right.""" + nest.assert_same_structure(left, right) + flat_left = nest.flatten(left) + flat_right = nest.flatten(right) + assert len(flat_left) == len(flat_right), ( + "Length of left {} and right {} should be same.". + format(len(flat_left), len(flat_right))) + + for o, i in zip(flat_left, flat_right): + # TODO(priyag): Add checks for other types like IndexedSlices. + if isinstance(o, ops.Tensor): + assert isinstance(i, ops.Tensor) + assert o.shape == i.shape, ( + "Shape {} of left {} doesn't match shape {} of right {}.". + format(o.shape, o, i.shape, i)) + assert o.dtype == i.dtype, ( + "Dtype {} of left {} doesn't match dtype {} of right {}.". + format(o.dtype, o, i.dtype, i)) + diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index b0bd92c7b054b52b071e5d7601bdc48117464822..c5b246e8041500e478478d1bb1527c3fe752b377 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -82,7 +82,7 @@ class DistributedValuesTest(test.TestCase): class DistributedDelegateTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetAttr(self): with ops.device("/device:CPU:0"): @@ -97,7 +97,7 @@ class DistributedDelegateTest(test.TestCase): with self.assertRaises(AttributeError): _ = v.y - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testOperatorOverride(self): with ops.device("/device:CPU:0"): v = values.DistributedDelegate({"/device:CPU:0": 7, "/device:GPU:0": 8}) @@ -363,7 +363,7 @@ class PerDeviceDatasetTest(test.TestCase): self._test_iterator_no_prefetch(devices, dataset, expected_values) self._test_iterator_with_prefetch(devices, dataset, expected_values) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testOneDevice(self): devices = ["/device:CPU:0"] dataset = dataset_ops.Dataset.range(10) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py index e281e81bdf0698c1f7b2f60fb27783dd1351773f..d1ce273499c8a646c0757844c91a785fa8d56ce4 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py @@ -61,6 +61,28 @@ class CholeskyOuterProductBijectorTest(test.TestCase): atol=0., rtol=1e-7) + def testNoBatchStaticJacobian(self): + x = np.eye(2) + bijector = bijectors.CholeskyOuterProduct() + + # The Jacobian matrix is 2 * tf.eye(2), which has jacobian determinant 4. + self.assertAllClose( + np.log(4), + self.evaluate(bijector.forward_log_det_jacobian(x, event_ndims=2))) + + def testNoBatchDynamicJacobian(self): + x = np.eye(2) + bijector = bijectors.CholeskyOuterProduct() + x_pl = array_ops.placeholder(dtypes.float32) + + with self.test_session(): + log_det_jacobian = bijector.forward_log_det_jacobian(x_pl, event_ndims=2) + + # The Jacobian matrix is 2 * tf.eye(2), which has jacobian determinant 4. + self.assertAllClose( + np.log(4), + log_det_jacobian.eval({x_pl: x})) + def testNoBatchStatic(self): x = np.array([[1., 0], [2, 1]]) # np.linalg.cholesky(y) y = np.array([[1., 2], [2, 5]]) # np.matmul(x, x.T) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py index caeaf2a0c6e4fff28c0edd82cb09ca0bcee85fc3..3530e142e4d1545e80a3b1bf1e8ddbf7819ba58a 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py @@ -31,7 +31,7 @@ from tensorflow.python.platform import test class FillTriangularBijectorTest(test.TestCase): """Tests the correctness of the FillTriangular bijector.""" - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBijector(self): x = np.float32(np.array([1., 2., 3.])) y = np.float32(np.array([[3., 0.], @@ -51,7 +51,7 @@ class FillTriangularBijectorTest(test.TestCase): ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2)) self.assertAllClose(ildj, 0.) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testShape(self): x_shape = tensor_shape.TensorShape([5, 4, 6]) y_shape = tensor_shape.TensorShape([5, 4, 3, 3]) @@ -76,7 +76,7 @@ class FillTriangularBijectorTest(test.TestCase): b.inverse_event_shape_tensor(y_shape.as_list())) self.assertAllEqual(x_shape_tensor, x_shape.as_list()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testShapeError(self): b = bijectors.FillTriangular(validate_args=True) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py index 18397035571561731698b06d90e20dc74e3cf83c..85d604e34ac25cf94b601470b7f166d9d414a8e3 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py @@ -29,7 +29,7 @@ from tensorflow.python.platform import test class MatrixInverseTriLBijectorTest(test.TestCase): """Tests the correctness of the Y = inv(tril) transformation.""" - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testComputesCorrectValues(self): inv = bijectors.MatrixInverseTriL(validate_args=True) self.assertEqual("matrix_inverse_tril", inv.name) @@ -51,7 +51,7 @@ class MatrixInverseTriLBijectorTest(test.TestCase): self.assertNear(expected_fldj_, fldj_, err=1e-3) self.assertNear(-expected_fldj_, ildj_, err=1e-3) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testOneByOneMatrix(self): inv = bijectors.MatrixInverseTriL(validate_args=True) x_ = np.array([[5.]], dtype=np.float32) @@ -70,7 +70,7 @@ class MatrixInverseTriLBijectorTest(test.TestCase): self.assertNear(expected_fldj_, fldj_, err=1e-3) self.assertNear(-expected_fldj_, ildj_, err=1e-3) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testZeroByZeroMatrix(self): inv = bijectors.MatrixInverseTriL(validate_args=True) x_ = np.eye(0, dtype=np.float32) @@ -89,7 +89,7 @@ class MatrixInverseTriLBijectorTest(test.TestCase): self.assertNear(expected_fldj_, fldj_, err=1e-3) self.assertNear(-expected_fldj_, ildj_, err=1e-3) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBatch(self): # Test batch computation with input shape (2, 1, 2, 2), i.e. batch shape # (2, 1). @@ -114,7 +114,7 @@ class MatrixInverseTriLBijectorTest(test.TestCase): self.assertAllClose(expected_fldj_, fldj_, atol=0., rtol=1e-3) self.assertAllClose(-expected_fldj_, ildj_, atol=0., rtol=1e-3) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testErrorOnInputRankTooLow(self): inv = bijectors.MatrixInverseTriL(validate_args=True) x_ = np.array([0.1], dtype=np.float32) @@ -149,7 +149,7 @@ class MatrixInverseTriLBijectorTest(test.TestCase): ## square_error_msg): ## inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testErrorOnInputNotLowerTriangular(self): inv = bijectors.MatrixInverseTriL(validate_args=True) x_ = np.array([[1., 2.], @@ -169,7 +169,7 @@ class MatrixInverseTriLBijectorTest(test.TestCase): triangular_error_msg): inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testErrorOnInputSingular(self): inv = bijectors.MatrixInverseTriL(validate_args=True) x_ = np.array([[1., 0.], diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py index a5f5219588fb3be67beb797ba68ed8148e9e9fd2..cb42331a21a6acdd5244c311a7def5359bb6c574 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py @@ -36,7 +36,7 @@ class OrderedBijectorTest(test.TestCase): def setUp(self): self._rng = np.random.RandomState(42) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBijectorVector(self): with self.test_session(): ordered = Ordered() @@ -82,7 +82,7 @@ class OrderedBijectorTest(test.TestCase): atol=0., rtol=1e-7) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testShapeGetters(self): with self.test_session(): x = tensor_shape.TensorShape([4]) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py index 566a7b3dff9b5d97a1cb143e0b32fc15984c3a02..d5b3367f9a31a9c602e0b138e617db68834b8229 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py @@ -46,7 +46,7 @@ class ScaleTriLBijectorTest(test.TestCase): x_ = self.evaluate(b.inverse(y)) self.assertAllClose(x, x_) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInvertible(self): # Generate random inputs from an unconstrained space, with diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py index 2ac06fce55b448a5f3da7ccb7f8766b5b1404ad7..d0098c3c105626da1da5855710169069ebeffbd9 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py @@ -40,7 +40,7 @@ class SoftsignBijectorTest(test.TestCase): def setUp(self): self._rng = np.random.RandomState(42) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBijectorBounds(self): bijector = Softsign(validate_args=True) with self.test_session(): @@ -54,7 +54,7 @@ class SoftsignBijectorTest(test.TestCase): with self.assertRaisesOpError("less than 1"): bijector.inverse_log_det_jacobian(3., event_ndims=0).eval() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBijectorForwardInverse(self): bijector = Softsign(validate_args=True) self.assertEqual("softsign", bijector.name) @@ -64,7 +64,7 @@ class SoftsignBijectorTest(test.TestCase): self.assertAllClose(y, self.evaluate(bijector.forward(x))) self.assertAllClose(x, self.evaluate(bijector.inverse(y))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBijectorLogDetJacobianEventDimsZero(self): bijector = Softsign(validate_args=True) y = self._rng.rand(2, 10) @@ -74,7 +74,7 @@ class SoftsignBijectorTest(test.TestCase): self.assertAllClose(ildj, self.evaluate( bijector.inverse_log_det_jacobian(y, event_ndims=0))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBijectorForwardInverseEventDimsOne(self): bijector = Softsign(validate_args=True) self.assertEqual("softsign", bijector.name) @@ -83,7 +83,7 @@ class SoftsignBijectorTest(test.TestCase): self.assertAllClose(y, self.evaluate(bijector.forward(x))) self.assertAllClose(x, self.evaluate(bijector.inverse(y))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBijectorLogDetJacobianEventDimsOne(self): bijector = Softsign(validate_args=True) y = self._rng.rand(2, 10) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py index 6428a68702274fae384ae3de6d03f7ca126e2346..efc9f266d1fb6bcc53ae318e218b0697825c0155 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py @@ -31,7 +31,7 @@ class TransformDiagonalBijectorTest(test.TestCase): def setUp(self): self._rng = np.random.RandomState(42) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBijector(self): x = np.float32(np.random.randn(3, 4, 4)) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py index bbbec2103aefd3f38a9b734bcd3f2e15fc8bb683..181c46d2e52552e641bc59c0fe94743f1af42845 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py @@ -544,7 +544,7 @@ class PadDynamicTest(_PadTest, test.TestCase): class TestMoveDimension(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_move_dimension_static_shape(self): x = random_ops.random_normal(shape=[200, 30, 4, 1, 6]) @@ -561,7 +561,7 @@ class TestMoveDimension(test.TestCase): x_perm = distribution_util.move_dimension(x, 4, 2) self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 6, 4, 1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_move_dimension_dynamic_shape(self): x_ = random_ops.random_normal(shape=[200, 30, 4, 1, 6]) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD index 03e26b198ea02ad1bef8bcd2f6076078ecd7df0b..42ecea034d77430924bd6f597bf42ec3f64fec92 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD @@ -34,7 +34,10 @@ py_test( name = "correlation_matrix_volumes_test", size = "medium", srcs = ["correlation_matrix_volumes_test.py"], - tags = ["no_pip"], + tags = [ + "no_pip", + "optonly", + ], deps = [ ":correlation_matrix_volumes_py", # For statistical testing diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py index 8267ee7df89f69f8d610e9507e0cca9f4a5d4323..3e1e4fc82971b71792d193ea8518dd402e4a4d9d 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py @@ -182,7 +182,20 @@ class CholeskyOuterProduct(bijector.Bijector): axis=-1) fldj = p_float * np.log(2.) + sum_weighted_log_diag - return fldj + # We finally need to undo adding an extra column in non-scalar cases + # where there is a single matrix as input. + if x.get_shape().ndims is not None: + if x.get_shape().ndims == 2: + fldj = array_ops.squeeze(fldj, axis=-1) + return fldj + + shape = array_ops.shape(fldj) + maybe_squeeze_shape = array_ops.concat([ + shape[:-1], + distribution_util.pick_vector( + math_ops.equal(array_ops.rank(x), 2), + np.array([], dtype=np.int32), shape[-1:])], 0) + return array_ops.reshape(fldj, maybe_squeeze_shape) def _make_columnar(self, x): """Ensures non-scalar input has at least one column. diff --git a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py index 0c762f17c9b770ecada57b6ce60a4825ba374dd9..214c6dca4a7f2b4cd6242e1b7ca78be9eeffb851 100644 --- a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py @@ -235,7 +235,7 @@ class OneHotCategorical(distribution.Distribution): return x return control_flow_ops.with_dependencies([ check_ops.assert_non_positive(x), - distribution_util.assert_close( + check_ops.assert_near( array_ops.zeros([], dtype=self.dtype), math_ops.reduce_logsumexp(x, axis=[-1])), ], x) diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py index 9b5bd7576f2a3c364e21da76dd3905a8c6e35829..25aaac379a7c54c832bdcf962e16f339522d61fc 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py @@ -299,7 +299,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): return x return control_flow_ops.with_dependencies([ check_ops.assert_non_positive(x), - distribution_util.assert_close( + check_ops.assert_near( array_ops.zeros([], dtype=self.dtype), math_ops.reduce_logsumexp(x, axis=[-1])), ], x) diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md index 4384431e7b9c3e6ef259391fa9efa5a35d23c86a..86d203452e24d6d73f3ebb17b989867905a61382 100644 --- a/tensorflow/contrib/eager/README.md +++ b/tensorflow/contrib/eager/README.md @@ -44,7 +44,7 @@ Installation instructions at https://www.tensorflow.org/install/ For an introduction to eager execution in TensorFlow, see: -- [User Guide](https://www.tensorflow.org/programmers_guide/eager) ([source](../../docs_src/programmers_guide/eager.md)) +- [User Guide](https://www.tensorflow.org/guide/eager) ([source](../../docs_src/guide/eager.md)) - Notebook: [Basic Usage](python/examples/notebooks/1_basics.ipynb) - Notebook: [Gradients](python/examples/notebooks/2_gradients.ipynb) - Notebook: [Importing Data](python/examples/notebooks/3_datasets.ipynb) diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index adf92c27ea0a27c5741bcdd175b277462cb28d02..58c548d798178a2848006cbf301f7d5cb2143f24 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -102,6 +102,7 @@ class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): with ops.device(self._device): self._buffer_resource_handle = prefetching_ops.function_buffering_resource( # pylint: disable=line-too-long string_arg=iter_string_handle, + output_types=self._flat_output_types, f=remote_fn, target_device=target, buffer_size=10, diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index 6f02c90368d966b8cf8d0dee09f9d2a5013c90c1..12155a459c29c353c57679c407e7dda25047a35c 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -15,6 +15,8 @@ py_library( "//tensorflow/contrib/eager/python/examples/revnet:config", "//tensorflow/contrib/eager/python/examples/rnn_colorbot", "//tensorflow/contrib/eager/python/examples/rnn_ptb", + "//tensorflow/contrib/eager/python/examples/sagan", + "//tensorflow/contrib/eager/python/examples/sagan:config", "//tensorflow/contrib/eager/python/examples/spinn:data", ], ) diff --git a/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb index bfcc7feb075c403d024772e0d715339d58877a51..d268cbcd9171b0f4a4f2ab27ad958374e521685b 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb @@ -9,7 +9,7 @@ "source": [ "# Eager Execution Tutorial: Importing Data\n", "\n", - "This notebook demonstrates the use of the [`tf.data.Dataset` API](https://www.tensorflow.org/programmers_guide/datasets) to build pipelines to feed data to your program. It covers:\n", + "This notebook demonstrates the use of the [`tf.data.Dataset` API](https://www.tensorflow.org/guide/datasets) to build pipelines to feed data to your program. It covers:\n", "\n", "* Creating a `Dataset`.\n", "* Iteration over a `Dataset` with eager execution enabled.\n", @@ -18,7 +18,7 @@ "\n", "If you're familiar with TensorFlow graphs, the API for constructing the `Dataset` object remains exactly the same when eager execution is enabled, but the process of iterating over elements of the dataset is slightly simpler.\n", "You can use Python iteration over the `tf.data.Dataset` object and do not need to explicitly create an `tf.data.Iterator` object.\n", - "As a result, the discussion on iterators in the [Programmer's Guide](https://www.tensorflow.org/programmers_guide/datasets) is not relevant when eager execution is enabled." + "As a result, the discussion on iterators in the [TensorFlow Guide](https://www.tensorflow.org/guide/datasets) is not relevant when eager execution is enabled." ] }, { @@ -63,7 +63,7 @@ "source": [ "# Step 1: Create a source `Dataset`\n", "\n", - "Create a _source_ dataset using one of the factory functions like [`Dataset.from_tensors`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensors), [`Dataset.from_tensor_slices`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensor_slices) or using objects that read from files like [`TextLineDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TextLineDataset) or [`TFRecordDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset). See the [Programmer's Guide](https://www.google.com/url?sa=D\u0026q=https%3A%2F%2Fwww.tensorflow.org%2Fprogrammers_guide%2Fdatasets%23reading_input_data) for more information." + "Create a _source_ dataset using one of the factory functions like [`Dataset.from_tensors`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensors), [`Dataset.from_tensor_slices`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensor_slices) or using objects that read from files like [`TextLineDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TextLineDataset) or [`TFRecordDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset). See the [TensorFlow Guide](https://www.tensorflow.org/guide/datasets#reading_input_data) for more information." ] }, { diff --git a/tensorflow/contrib/eager/python/examples/resnet50/BUILD b/tensorflow/contrib/eager/python/examples/resnet50/BUILD index 0c0e28dd95c68dc300384a128eb5aa2208f63a0d..68a84d5fbb4f13e4ebe0d71e3f5caebe97e2101c 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/BUILD +++ b/tensorflow/contrib/eager/python/examples/resnet50/BUILD @@ -51,5 +51,6 @@ cuda_py_test( "noasan", "nomsan", "notsan", + "optonly", ], ) diff --git a/tensorflow/contrib/eager/python/examples/revnet/BUILD b/tensorflow/contrib/eager/python/examples/revnet/BUILD index bfb53cfff86650c28fdd934763b1fb40cc5c796c..432bb546f83932d0e0a465d7af7c641b60d2e564 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/BUILD +++ b/tensorflow/contrib/eager/python/examples/revnet/BUILD @@ -62,6 +62,9 @@ cuda_py_test( ":blocks", "//tensorflow:tensorflow_py", ], + tags = [ + "optonly", + ], ) cuda_py_test( @@ -73,4 +76,39 @@ cuda_py_test( ":revnet", "//tensorflow:tensorflow_py", ], + tags = [ + "optonly", + ], +) + +# Training +py_library( + name = "cifar_input", + srcs = ["cifar_input.py"], + srcs_version = "PY2AND3", + deps = [ + ":revnet", + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "cifar_tfrecords", + srcs = ["cifar_tfrecords.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "main", + srcs = ["main.py"], + srcs_version = "PY2AND3", + deps = [ + ":cifar_input", + ":config", + ":revnet", + "//tensorflow:tensorflow_py", + ], ) diff --git a/tensorflow/contrib/eager/python/examples/revnet/README.md b/tensorflow/contrib/eager/python/examples/revnet/README.md new file mode 100644 index 0000000000000000000000000000000000000000..21fc44febc8abdc30daad1b35d8434b083360bdf --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/README.md @@ -0,0 +1,45 @@ +# RevNet with TensorFlow eager execution + +This folder contains an TensorFlow eager implementation of the [Reversible Residual Network](https://arxiv.org/pdf/1707.04585.pdf) adapted from the released implementation by the authors. The presented implementation can be ran both in eager and graph mode. The code is considerably simplified with `tf.GradientTape`. Moreover, we reduce the step of reconstructing the outputs. This saves us from using `tf.stop_gradient` and makes the model run faster. + +## Content + +- `revnet.py`: The RevNet model. +- `blocks.py`: The relevant reversible blocks. +- `cifar_tfrecords.py`: Script to generate the TFRecords for both CIFAR-10 and CIFAR-100. +- `cifar_input.py`: Script to read from TFRecords and generate dataset objects with the `tf.data` API. +- `config.py`: Configuration file for network architectures and training hyperparameters. +- `main.py`: Main training and evaluation script. +- `ops.py`: Auxiliary downsampling operation. + +## To run +- Make sure you have installed TensorFlow 1.9+ or the latest `tf-nightly` +or `tf-nightly-gpu` pip package in order to access the eager execution feature. + +- First run + +```bash +python cifar_tfrecords.py --data_dir ${PWD}/cifar +``` +to download the cifar dataset and convert them +to TFRecords. This produces TFRecord files for both CIFAR-10 and CIFAR-100. + +- To train a model run + +```bash +python main.py --data_dir ${PWD}/cifar +``` + +- Optional arguments for `main.py` include + - `train_dir`: Directory to store eventfiles and checkpoints. + - `restore`: Restore the latest checkpoint. + - `validate`: Use validation set for training monitoring. + - `manual_grad`: Use the manually defined gradient map given by the authors. + - `dataset`: Use either `cifar-10` or `cifar-100` + +## Performance +- With the current implementation, RevNet-38 achieves >92% on CIFAR-10 and >71% on CIFAR-100. + +## Reference +The Reversible Residual Network: Backpropagation Without Storing Activations. +Aidan N. Gomez, Mengye Ren, Raquel Urtasun, Roger B. Grosse. Neural Information Processing Systems (NIPS), 2017. diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks.py b/tensorflow/contrib/eager/python/examples/revnet/blocks.py index fb4f9f068f062802cda4610ced01c50da3836e04..74c1825a49a702c8c4cc8ec04ebb87917bca380d 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/blocks.py +++ b/tensorflow/contrib/eager/python/examples/revnet/blocks.py @@ -24,6 +24,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import six import tensorflow as tf from tensorflow.contrib.eager.python.examples.revnet import ops @@ -93,9 +94,18 @@ class RevBlock(tf.keras.Model): for i in reversed(range(len(self.blocks))): block = self.blocks[i] - y_inv = x if i == 0 else block.backward(y, training=training) + if i == 0: + y_inv = x + else: + # Don't update running stats when reconstructing activations + vars_and_vals = block.get_moving_stats() + y_inv = block.backward(y, training=training) + block.restore_moving_stats(vars_and_vals) + + # Update running stats when computing gradients during training dy, grads, vars_ = block.backward_grads_and_vars( y_inv, dy, training=training) + grads_all += grads vars_all += vars_ @@ -159,17 +169,18 @@ class _Residual(tf.keras.Model): """Apply residual block to inputs.""" x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis) - f_x2 = self.f.call(x2, training=training) + f_x2 = self.f(x2, training=training) # TODO(lxuechen): Replace with simpler downsampling x1_down = ops.downsample( x1, self.filters // 2, self.strides, axis=self.axis) x2_down = ops.downsample( x2, self.filters // 2, self.strides, axis=self.axis) y1 = f_x2 + x1_down - g_y1 = self.g.call(y1, training=training) # self.g(y1) gives pylint error + g_y1 = self.g(y1, training=training) y2 = g_y1 + x2_down - if not concat: # Concat option needed for correct backward grads + if not concat: # For correct backward grads return y1, y2 + return tf.concat([y1, y2], axis=self.axis) def backward(self, y, training=True): @@ -178,9 +189,9 @@ class _Residual(tf.keras.Model): assert self.strides == (1, 1) y1, y2 = tf.split(y, num_or_size_splits=2, axis=self.axis) - g_y1 = self.g.call(y1, training=training) + g_y1 = self.g(y1, training=training) x2 = y2 - g_y1 - f_x2 = self.f.call(x2, training=training) + f_x2 = self.f(x2, training=training) x1 = y1 - f_x2 return tf.concat([x1, x2], axis=self.axis) @@ -189,8 +200,8 @@ class _Residual(tf.keras.Model): """Manually compute backward gradients given input and output grads.""" with tf.GradientTape(persistent=True) as tape: - x_stop = tf.stop_gradient(x) - x1, x2 = tf.split(x_stop, num_or_size_splits=2, axis=self.axis) + x = tf.identity(x) # TODO(lxuechen): Remove after b/110264016 is fixed + x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis) tape.watch([x1, x2]) # Stitch back x for `call` so tape records correct grads x = tf.concat([x1, x2], axis=self.axis) @@ -200,22 +211,38 @@ class _Residual(tf.keras.Model): x2, self.filters // 2, self.strides, axis=self.axis) grads_combined = tape.gradient( - y2, [y1] + self.g.variables, output_gradients=[dy2]) + y2, [y1] + self.g.trainable_variables, output_gradients=[dy2]) dy2_y1, dg = grads_combined[0], grads_combined[1:] dy1_plus = dy2_y1 + dy1 grads_combined = tape.gradient( - y1, [x1, x2] + self.f.variables, output_gradients=[dy1_plus]) + y1, [x1, x2] + self.f.trainable_variables, output_gradients=[dy1_plus]) dx1, dx2, df = grads_combined[0], grads_combined[1], grads_combined[2:] dx2 += tape.gradient(x2_down, [x2], output_gradients=[dy2])[0] del tape grads = df + dg - vars_ = self.f.variables + self.g.variables + vars_ = self.f.trainable_variables + self.g.trainable_variables return tf.concat([dx1, dx2], axis=self.axis), grads, vars_ + def get_moving_stats(self): + vars_and_vals = {} + + def _is_moving_var(v): # pylint: disable=invalid-name + n = v.name + return n.endswith("moving_mean:0") or n.endswith("moving_variance:0") + + for v in filter(_is_moving_var, self.f.variables + self.g.variables): + vars_and_vals[v] = v.read_value() + + return vars_and_vals + + def restore_moving_stats(self, vars_and_vals): + for var_, val in six.iteritems(vars_and_vals): + var_.assign(val) + def _BottleneckResidualInner(filters, strides, @@ -246,7 +273,7 @@ def _BottleneckResidualInner(filters, model.add( tf.keras.layers.BatchNormalization( axis=axis, input_shape=input_shape, fused=fused)) - model.add(tf.keras.layers.LeakyReLU(alpha=0.)) + model.add(tf.keras.layers.Activation("relu")) model.add( tf.keras.layers.Conv2D( filters=filters // 4, @@ -258,7 +285,7 @@ def _BottleneckResidualInner(filters, padding="SAME")) model.add(tf.keras.layers.BatchNormalization(axis=axis, fused=fused)) - model.add(tf.keras.layers.LeakyReLU(alpha=0.)) + model.add(tf.keras.layers.Activation("relu")) model.add( tf.keras.layers.Conv2D( filters=filters // 4, @@ -269,7 +296,7 @@ def _BottleneckResidualInner(filters, padding="SAME")) model.add(tf.keras.layers.BatchNormalization(axis=axis, fused=fused)) - model.add(tf.keras.layers.LeakyReLU(alpha=0.)) + model.add(tf.keras.layers.Activation("relu")) model.add( tf.keras.layers.Conv2D( filters=filters, @@ -310,7 +337,7 @@ def _ResidualInner(filters, model.add( tf.keras.layers.BatchNormalization( axis=axis, input_shape=input_shape, fused=fused)) - model.add(tf.keras.layers.LeakyReLU(alpha=0.)) + model.add(tf.keras.layers.Activation("relu")) model.add( tf.keras.layers.Conv2D( filters=filters, @@ -322,7 +349,7 @@ def _ResidualInner(filters, padding="SAME")) model.add(tf.keras.layers.BatchNormalization(axis=axis, fused=fused)) - model.add(tf.keras.layers.LeakyReLU(alpha=0.)) + model.add(tf.keras.layers.Activation("relu")) model.add( tf.keras.layers.Conv2D( filters=filters, diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py index f4436fd92506d54f1206fbfd424b897f9835657d..a28ca6e3e076ef1d52ab5a34e5559536cf5d52cc 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py +++ b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py @@ -240,13 +240,12 @@ class _ResidualTest(tf.test.TestCase): x = tf.random_normal(shape=data_shape) residual = blocks._Residual( filters=16, strides=(1, 1), input_shape=input_shape) + y_tr, y_ev = residual(x, training=True), residual(x, training=False) - x_ = residual.backward(y_tr, training=True) - # The numerical loss is alarming; reconstructed inputs could differ from - # the original inputs often by more than 1e-3 - self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01) x_ = residual.backward(y_ev, training=False) - self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01) + self.assertAllClose(x, x_, rtol=1e-1, atol=1e-1) + x_ = residual.backward(y_tr, training=True) # This updates moving avg + self.assertAllClose(x, x_, rtol=1e-1, atol=1e-1) def test_backward_channels_last(self): """Test `backward` function with `channels_last` data format.""" @@ -259,12 +258,12 @@ class _ResidualTest(tf.test.TestCase): strides=(1, 1), input_shape=input_shape, data_format="channels_last") + y_tr, y_ev = residual(x, training=True), residual(x, training=False) - x_ = residual.backward(y_tr, training=True) - # Egregious numerical error - self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01) x_ = residual.backward(y_ev, training=False) - self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01) + self.assertAllClose(x, x_, rtol=1e-1, atol=1e-1) + x_ = residual.backward(y_tr, training=True) # This updates moving avg + self.assertAllClose(x, x_, rtol=1e-1, atol=1e-1) def test_backward_grads_and_vars_channels_first(self): """Test `backward_grads` function with `channels_first` data format.""" @@ -278,6 +277,8 @@ class _ResidualTest(tf.test.TestCase): dy = tf.random_normal(shape=data_shape) residual = blocks._Residual( filters=16, strides=(1, 1), input_shape=input_shape) + + vars_and_vals = residual.get_moving_stats() dx_tr, grads_tr, vars_tr = residual.backward_grads_and_vars( x, dy=dy, training=True) dx_ev, grads_ev, vars_ev = residual.backward_grads_and_vars( @@ -289,10 +290,23 @@ class _ResidualTest(tf.test.TestCase): self.assertTrue(isinstance(vars_ev, list)) for grad_tr, var_tr, grad_ev, var_ev in zip(grads_tr, vars_tr, grads_ev, vars_ev): - if grad_tr is not None: # Batch norm moving mean, var gives None grad - self.assertEqual(grad_tr.shape, grad_ev.shape) - self.assertEqual(var_tr.shape, var_ev.shape) - self.assertEqual(grad_tr.shape, var_tr.shape) + self.assertEqual(grad_tr.shape, grad_ev.shape) + self.assertEqual(var_tr.shape, var_ev.shape) + self.assertEqual(grad_tr.shape, var_tr.shape) + + # Compare against the true gradient computed by the tape + residual.restore_moving_stats(vars_and_vals) + with tf.GradientTape(persistent=True) as tape: + tape.watch(x) + y = residual(x, training=True) + grads = tape.gradient( + y, [x] + residual.trainable_variables, output_gradients=[dy]) + dx_tr_true, grads_tr_true = grads[0], grads[1:] + + del tape + + self.assertAllClose(dx_tr, dx_tr_true, rtol=1e-1, atol=1e-1) + self.assertAllClose(grads_tr, grads_tr_true, rtol=1e-1, atol=1e-1) def test_backward_grads_and_vars_channels_last(self): """Test `backward_grads` function with `channels_last` data format.""" @@ -306,6 +320,7 @@ class _ResidualTest(tf.test.TestCase): strides=(1, 1), input_shape=input_shape, data_format="channels_last") + dx_tr, grads_tr, vars_tr = residual.backward_grads_and_vars( x, dy=dy, training=True) dx_ev, grads_ev, vars_ev = residual.backward_grads_and_vars( @@ -317,10 +332,9 @@ class _ResidualTest(tf.test.TestCase): self.assertTrue(isinstance(vars_ev, list)) for grad_tr, var_tr, grad_ev, var_ev in zip(grads_tr, vars_tr, grads_ev, vars_ev): - if grad_tr is not None: # Batch norm moving mean, var gives None grad - self.assertEqual(grad_tr.shape, grad_ev.shape) - self.assertEqual(var_tr.shape, var_ev.shape) - self.assertEqual(grad_tr.shape, var_tr.shape) + self.assertEqual(grad_tr.shape, grad_ev.shape) + self.assertEqual(var_tr.shape, var_ev.shape) + self.assertEqual(grad_tr.shape, var_tr.shape) class _ResidualInnerTest(tf.test.TestCase): diff --git a/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py new file mode 100644 index 0000000000000000000000000000000000000000..e1d8b3a0559704bd8f00a8cc4b9fe735ad1de5f9 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py @@ -0,0 +1,116 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Script for reading and loading CIFAR-10.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import tensorflow as tf + +# Global constants describing the CIFAR data set. +IMAGE_HEIGHT = 32 +IMAGE_WIDTH = 32 +NUM_CHANNEL = 3 + + +def get_ds_from_tfrecords(data_dir, + split, + data_aug=True, + batch_size=100, + epochs=None, + shuffle=True, + data_format="channels_first", + num_parallel_calls=8, + prefetch=0, + div255=True, + dtype=tf.float32): + """Returns a tf.train.Dataset object from reading tfrecords. + + Args: + data_dir: Directory of tfrecords + split: "train", "validation", or "test" + data_aug: Apply data augmentation if True + batch_size: Batch size of dataset object + epochs: Number of epochs to repeat the dataset; default `None` means + repeating indefinitely + shuffle: Shuffle the dataset if True + data_format: `channels_first` or `channels_last` + num_parallel_calls: Number of threads for dataset preprocess + prefetch: Buffer size for prefetch + div255: Divide the images by 255 if True + dtype: Data type of images + Returns: + A tf.train.Dataset object + + Raises: + ValueError: Unknown split + """ + + if split not in ["train", "validation", "test", "train_all"]: + raise ValueError("Unknown split {}".format(split)) + + def _parser(serialized_example): + """Parses a single tf.Example into image and label tensors.""" + features = tf.parse_single_example( + serialized_example, + features={ + "image": tf.FixedLenFeature([], tf.string), + "label": tf.FixedLenFeature([], tf.int64), + }) + image = tf.decode_raw(features["image"], tf.uint8) + # Initially reshaping to [H, W, C] does not work + image = tf.reshape(image, [NUM_CHANNEL, IMAGE_HEIGHT, IMAGE_WIDTH]) + # This is needed for `tf.image.resize_image_with_crop_or_pad` + image = tf.transpose(image, [1, 2, 0]) + + image = tf.cast(image, dtype) + label = tf.cast(features["label"], tf.int32) + + if data_aug: + image = tf.image.resize_image_with_crop_or_pad(image, IMAGE_HEIGHT + 4, + IMAGE_WIDTH + 4) + image = tf.random_crop(image, [IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNEL]) + image = tf.image.random_flip_left_right(image) + + if data_format == "channels_first": + image = tf.transpose(image, [2, 0, 1]) + + if div255: + image /= 255. + + return image, label + + filename = os.path.join(data_dir, split + ".tfrecords") + dataset = tf.data.TFRecordDataset(filename) + dataset = dataset.repeat(epochs) + dataset = dataset.map(_parser, num_parallel_calls=num_parallel_calls) + dataset = dataset.prefetch(prefetch) + + if shuffle: + # Find the right size according to the split + size = { + "train": 40000, + "validation": 10000, + "test": 10000, + "train_all": 50000 + }[split] + dataset = dataset.shuffle(size) + + dataset = dataset.batch(batch_size) + + return dataset diff --git a/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py b/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py new file mode 100644 index 0000000000000000000000000000000000000000..f79428b2a97f0ac2ce991f4c26b9123cddc24325 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py @@ -0,0 +1,123 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Read CIFAR-10 data from pickled numpy arrays and writes TFRecords. + +Generates tf.train.Example protos and writes them to TFRecord files from the +python version of the CIFAR-10 dataset downloaded from +https://www.cs.toronto.edu/~kriz/cifar.html. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +import tarfile + +from absl import flags +from six.moves import cPickle as pickle +from six.moves import urllib +import tensorflow as tf + +CIFAR_FILENAME = 'cifar-10-python.tar.gz' +CIFAR_DOWNLOAD_URL = 'https://www.cs.toronto.edu/~kriz/' + CIFAR_FILENAME +CIFAR_LOCAL_FOLDER = 'cifar-10-batches-py' + + +def download_and_extract(data_dir): + """Download CIFAR-10 if not already downloaded.""" + filepath = os.path.join(data_dir, CIFAR_FILENAME) + if tf.gfile.Exists(filepath): + return filepath + if not tf.gfile.Exists(data_dir): + tf.gfile.MakeDirs(data_dir) + + urllib.request.urlretrieve(CIFAR_DOWNLOAD_URL, filepath) + tarfile.open(os.path.join(filepath), 'r:gz').extractall(data_dir) + return filepath + + +def _int64_feature(value): + return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) + + +def _bytes_feature(value): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + +def _get_file_names(): + """Returns the file names expected to exist in the input_dir.""" + file_names = {} + file_names['train'] = ['data_batch_%d' % i for i in range(1, 5)] + file_names['validation'] = ['data_batch_5'] + file_names['test'] = ['test_batch'] + return file_names + + +def read_pickle_from_file(filename): + with tf.gfile.Open(filename, 'rb') as f: + if sys.version_info >= (3, 0): + data_dict = pickle.load(f, encoding='bytes') + else: + data_dict = pickle.load(f) + return data_dict + + +def convert_to_tfrecord(input_files, output_file): + """Converts files with pickled data to TFRecords.""" + print('Generating %s' % output_file) + with tf.python_io.TFRecordWriter(output_file) as record_writer: + for input_file in input_files: + data_dict = read_pickle_from_file(input_file) + data = data_dict[b'data'] + labels = data_dict[b'labels'] + num_entries_in_batch = len(labels) + + for i in range(num_entries_in_batch): + example = tf.train.Example( + features=tf.train.Features( + feature={ + 'image': _bytes_feature(data[i].tobytes()), + 'label': _int64_feature(labels[i]) + })) + record_writer.write(example.SerializeToString()) + + +def main(_): + print('Download from {} and extract.'.format(CIFAR_DOWNLOAD_URL)) + download_and_extract(FLAGS.data_dir) + file_names = _get_file_names() + input_dir = os.path.join(FLAGS.data_dir, CIFAR_LOCAL_FOLDER) + + for mode, files in file_names.items(): + input_files = [os.path.join(input_dir, f) for f in files] + output_file = os.path.join(FLAGS.data_dir, mode + '.tfrecords') + try: + os.remove(output_file) + except OSError: + pass + convert_to_tfrecord(input_files, output_file) + print('Done!') + + +if __name__ == '__main__': + FLAGS = flags.FLAGS + flags.DEFINE_string( + 'data_dir', + default=None, + help='Directory to download and extract CIFAR-10 to.') + + tf.app.run(main) diff --git a/tensorflow/contrib/eager/python/examples/revnet/config.py b/tensorflow/contrib/eager/python/examples/revnet/config.py index 495a78d550a48fa56d6cfa276e47c9ff846edff3..30b0edbf43304f4dd1b3a10165bdb28886d2d152 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/config.py +++ b/tensorflow/contrib/eager/python/examples/revnet/config.py @@ -27,6 +27,7 @@ from __future__ import division from __future__ import print_function import tensorflow as tf +tfe = tf.contrib.eager def get_hparams_cifar_38(): @@ -41,11 +42,11 @@ def get_hparams_cifar_38(): config.add_hparam("n_res", [3, 3, 3]) config.add_hparam("filters", [32, 64, 112]) config.add_hparam("strides", [1, 2, 2]) - config.add_hparam("batch_size", 10) + config.add_hparam("batch_size", 100) config.add_hparam("bottleneck", False) config.add_hparam("fused", True) config.add_hparam("init_max_pool", False) - if tf.test.is_gpu_available(): + if tfe.num_gpus() > 0: config.add_hparam("input_shape", (3, 32, 32)) config.add_hparam("data_format", "channels_first") else: @@ -60,13 +61,15 @@ def get_hparams_cifar_38(): config.add_hparam("max_train_iter", 80000) config.add_hparam("seed", 1234) config.add_hparam("shuffle", True) - config.add_hparam("prefetch", True) - config.add_hparam("print_every", 50) + config.add_hparam("log_every", 500) + config.add_hparam("save_every", 500) config.add_hparam("dtype", tf.float32) - config.add_hparam("eval_batch_size", 500) + config.add_hparam("eval_batch_size", 1000) config.add_hparam("div255", True) - # For tf.data.Dataset - config.add_hparam("epochs", config.max_train_iter // config.batch_size) + # TODO(lxuechen): This is imprecise, when training with validation set, + # we only have 40k images in training data + config.add_hparam("iters_per_epoch", 50000 // config.batch_size) + config.add_hparam("epochs", config.max_train_iter // config.iters_per_epoch) return config @@ -102,13 +105,14 @@ def get_hparams_imagenet_56(): config.add_hparam("max_train_iter", 600000) config.add_hparam("seed", 1234) config.add_hparam("shuffle", True) - config.add_hparam("prefetch", True) - config.add_hparam("print_every", 50) + config.add_hparam("log_every", 50) + config.add_hparam("save_every", 50) config.add_hparam("dtype", tf.float32) - config.add_hparam("eval_batch_size", 500) + config.add_hparam("eval_batch_size", 1000) config.add_hparam("div255", True) - # For tf.data.Dataset - config.add_hparam("epochs", config.max_train_iter // config.batch_size) + # TODO(lxuechen): Update this according to ImageNet data + config.add_hparam("iters_per_epoch", 50000 // config.batch_size) + config.add_hparam("epochs", config.max_train_iter // config.iters_per_epoch) if config.bottleneck: filters = [f * 4 for f in config.filters] diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py new file mode 100644 index 0000000000000000000000000000000000000000..106559250940acba1a7bb600283e25dae6252e4b --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/main.py @@ -0,0 +1,230 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Eager execution workflow with RevNet train on CIFAR-10.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +from absl import flags +import tensorflow as tf +from tqdm import tqdm +from tensorflow.contrib.eager.python.examples.revnet import cifar_input +from tensorflow.contrib.eager.python.examples.revnet import config as config_ +from tensorflow.contrib.eager.python.examples.revnet import revnet +tfe = tf.contrib.eager + + +def main(_): + """Eager execution workflow with RevNet trained on CIFAR-10.""" + if FLAGS.data_dir is None: + raise ValueError("No supplied data directory") + + if not os.path.exists(FLAGS.data_dir): + raise ValueError("Data directory {} does not exist".format(FLAGS.data_dir)) + + tf.enable_eager_execution() + config = config_.get_hparams_cifar_38() + + if FLAGS.validate: + # 40k Training set + ds_train = cifar_input.get_ds_from_tfrecords( + data_dir=FLAGS.data_dir, + split="train", + data_aug=True, + batch_size=config.batch_size, + epochs=config.epochs, + shuffle=config.shuffle, + data_format=config.data_format, + dtype=config.dtype, + prefetch=config.batch_size) + # 10k Training set + ds_validation = cifar_input.get_ds_from_tfrecords( + data_dir=FLAGS.data_dir, + split="validation", + data_aug=False, + batch_size=config.eval_batch_size, + epochs=1, + shuffle=False, + data_format=config.data_format, + dtype=config.dtype, + prefetch=config.eval_batch_size) + else: + # 50k Training set + ds_train = cifar_input.get_ds_from_tfrecords( + data_dir=FLAGS.data_dir, + split="train_all", + data_aug=True, + batch_size=config.batch_size, + epochs=config.epochs, + shuffle=config.shuffle, + data_format=config.data_format, + dtype=config.dtype, + prefetch=config.batch_size) + + # Always compute loss and accuracy on whole training and test set + ds_train_one_shot = cifar_input.get_ds_from_tfrecords( + data_dir=FLAGS.data_dir, + split="train_all", + data_aug=False, + batch_size=config.eval_batch_size, + epochs=1, + shuffle=False, + data_format=config.data_format, + dtype=config.dtype, + prefetch=config.eval_batch_size) + + ds_test = cifar_input.get_ds_from_tfrecords( + data_dir=FLAGS.data_dir, + split="test", + data_aug=False, + batch_size=config.eval_batch_size, + epochs=1, + shuffle=False, + data_format=config.data_format, + dtype=config.dtype, + prefetch=config.eval_batch_size) + + model = revnet.RevNet(config=config) + global_step = tfe.Variable(1, trainable=False) + learning_rate = tf.train.piecewise_constant( + global_step, config.lr_decay_steps, config.lr_list) + optimizer = tf.train.MomentumOptimizer( + learning_rate, momentum=config.momentum) + checkpointer = tf.train.Checkpoint( + optimizer=optimizer, model=model, optimizer_step=global_step) + + if FLAGS.train_dir: + summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir) + if FLAGS.restore: + latest_path = tf.train.latest_checkpoint(FLAGS.train_dir) + checkpointer.restore(latest_path) + print("Restored latest checkpoint at path:\"{}\" " + "with global_step: {}".format(latest_path, global_step.numpy())) + sys.stdout.flush() + + warmup(model, config) + + for x, y in ds_train: + loss = train_one_iter(model, x, y, optimizer, global_step=global_step) + + if global_step.numpy() % config.log_every == 0: + it_train = ds_train_one_shot.make_one_shot_iterator() + acc_train, loss_train = evaluate(model, it_train) + it_test = ds_test.make_one_shot_iterator() + acc_test, loss_test = evaluate(model, it_test) + if FLAGS.validate: + it_validation = ds_validation.make_one_shot_iterator() + acc_validation, loss_validation = evaluate(model, it_validation) + print("Iter {}, " + "training set accuracy {:.4f}, loss {:.4f}; " + "validation set accuracy {:.4f}, loss {:4.f}" + "test accuracy {:.4f}, loss {:.4f}".format( + global_step.numpy(), acc_train, loss_train, acc_validation, + loss_validation, acc_test, loss_test)) + else: + print("Iter {}, " + "training set accuracy {:.4f}, loss {:.4f}; " + "test accuracy {:.4f}, loss {:.4f}".format( + global_step.numpy(), acc_train, loss_train, acc_test, + loss_test)) + sys.stdout.flush() + + if FLAGS.train_dir: + with summary_writer.as_default(): + with tf.contrib.summary.always_record_summaries(): + tf.contrib.summary.scalar("Training loss", loss) + tf.contrib.summary.scalar("Test accuracy", acc_test) + if FLAGS.validate: + tf.contrib.summary.scalar("Validation accuracy", acc_validation) + + if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir: + saved_path = checkpointer.save( + file_prefix=os.path.join(FLAGS.train_dir, "ckpt")) + print("Saved checkpoint at path: \"{}\" " + "with global_step: {}".format(saved_path, global_step.numpy())) + sys.stdout.flush() + + +def warmup(model, config, steps=1): + mock_input = tf.random_normal((config.batch_size,) + config.input_shape) + for _ in range(steps): + model(mock_input, training=False) + + +def train_one_iter(model, + inputs, + labels, + optimizer, + global_step=None, + verbose=False): + """Train for one iteration.""" + if FLAGS.manual_grad: + if verbose: + print("Using manual gradients") + grads, vars_, loss = model.compute_gradients(inputs, labels) + optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) + else: # For correctness validation + if verbose: + print("Not using manual gradients") + with tf.GradientTape() as tape: + logits, _ = model(inputs, training=True) + loss = model.compute_loss(logits=logits, labels=labels) + grads = tape.gradient(loss, model.trainable_variables) + optimizer.apply_gradients( + zip(grads, model.trainable_variables), global_step=global_step) + + return loss.numpy() + + +def evaluate(model, iterator): + """Compute accuracy with the given dataset iterator.""" + mean_loss = tfe.metrics.Mean() + accuracy = tfe.metrics.Accuracy() + for x, y in tqdm(iterator): + logits, _ = model(x, training=False) + loss = model.compute_loss(logits=logits, labels=y) + accuracy( + labels=tf.cast(y, tf.int64), + predictions=tf.argmax(logits, axis=1, output_type=tf.int64)) + mean_loss(loss) + + return accuracy.result().numpy(), mean_loss.result().numpy() + + +if __name__ == "__main__": + flags.DEFINE_string( + "train_dir", + default=None, + help="[Optional] Directory to store the training information") + flags.DEFINE_string( + "data_dir", default=None, help="Directory to load tfrecords") + flags.DEFINE_boolean( + "restore", + default=False, + help="[Optional] Restore the latest checkpoint from `train_dir` if True") + flags.DEFINE_boolean( + "validate", + default=False, + help="[Optional] Use the validation set or not for hyperparameter search") + flags.DEFINE_boolean( + "manual_grad", + default=False, + help="[Optional] Use manual gradient graph to save memory") + FLAGS = flags.FLAGS + tf.app.run(main) diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py index aa3f7efe1b6d8a44ce1bef065f24fa5c35cd404a..0228bff6fab9b9704bfea0836f06a6ec0ff7839e 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/revnet.py +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py @@ -27,14 +27,11 @@ from __future__ import print_function import functools import operator +import six import tensorflow as tf from tensorflow.contrib.eager.python.examples.revnet import blocks -# Global Conventions: -# 1) Default data format is NCWH, targeting GPU -# 2) Each block has attribute axis, inferred from data_format -# 3) Default training option to True for batch normalization class RevNet(tf.keras.Model): """RevNet that depends on all the blocks.""" @@ -51,6 +48,7 @@ class RevNet(tf.keras.Model): self._init_block = self._construct_init_block() self._block_list = self._construct_intermediate_blocks() self._final_block = self._construct_final_block() + self._moving_stats_vars = None def _construct_init_block(self): init_block = tf.keras.Sequential( @@ -65,7 +63,7 @@ class RevNet(tf.keras.Model): input_shape=self.config.input_shape), tf.keras.layers.BatchNormalization( axis=self.axis, fused=self.config.fused), - tf.keras.layers.LeakyReLU(alpha=0.) + tf.keras.layers.Activation("relu"), ], name="init") if self.config.init_max_pool: @@ -100,7 +98,7 @@ class RevNet(tf.keras.Model): axis=self.axis, input_shape=input_shape, fused=self.config.fused), - tf.keras.layers.LeakyReLU(alpha=0.), # Vanilla ReLU + tf.keras.layers.Activation("relu"), tf.keras.layers.GlobalAveragePooling2D( data_format=self.config.data_format), tf.keras.layers.Dense(self.config.n_classes) @@ -157,7 +155,6 @@ class RevNet(tf.keras.Model): def call(self, inputs, training=True): """Forward pass.""" - # Only store hidden states during training if training: saved_hidden = [inputs] @@ -185,17 +182,22 @@ class RevNet(tf.keras.Model): def compute_gradients(self, inputs, labels, training=True): """Manually computes gradients. + This method also SILENTLY updates the running averages of batch + normalization when `training` is set to True. + Args: inputs: Image tensor, either NHWC or NCHW, conforming to `data_format` labels: One-hot labels for classification - training: for batch normalization + training: Use the mini-batch stats in batch norm if set to True Returns: - list of tuple each being (grad, var) for optimizer use + list of tuples each being (grad, var) for optimizer to use """ - # Forward pass record hidden states before downsampling + # Run forward pass to record hidden states; avoid updating running averages + vars_and_vals = self.get_moving_stats() _, saved_hidden = self.call(inputs, training=training) + self.restore_moving_stats(vars_and_vals) grads_all = [] vars_all = [] @@ -203,14 +205,17 @@ class RevNet(tf.keras.Model): # Manually backprop through last block x = saved_hidden[-1] with tf.GradientTape() as tape: + x = tf.identity(x) # TODO(lxuechen): Remove after b/110264016 is fixed tape.watch(x) + # Running stats updated below logits = self._final_block(x, training=training) - cost = self.compute_loss(logits, labels) + loss = self.compute_loss(logits, labels) - grads_combined = tape.gradient(cost, [x] + self._final_block.variables) + grads_combined = tape.gradient(loss, + [x] + self._final_block.trainable_variables) dy, grads_ = grads_combined[0], grads_combined[1:] grads_all += grads_ - vars_all += self._final_block.variables + vars_all += self._final_block.trainable_variables # Manually backprop through intermediate blocks for block in reversed(self._block_list): @@ -227,37 +232,39 @@ class RevNet(tf.keras.Model): assert not saved_hidden # Cleared after backprop with tf.GradientTape() as tape: - y = self._init_block(x, training=training) # Recomputing + x = tf.identity(x) # TODO(lxuechen): Remove after b/110264016 is fixed + # Running stats updated below + y = self._init_block(x, training=training) grads_all += tape.gradient( - y, self._init_block.variables, output_gradients=[dy]) - vars_all += self._init_block.variables + y, self._init_block.trainable_variables, output_gradients=[dy]) + vars_all += self._init_block.trainable_variables - return grads_all, vars_all + # Apply weight decay + grads_all = self._apply_weight_decay(grads_all, vars_all) - def train_step(self, - inputs, - labels, - optimizer, - global_step=None, - report=False): - """Train for one iteration.""" + return grads_all, vars_all, loss - grads_all, vars_all = self.compute_gradients(inputs, labels, training=True) - optimizer.apply_gradients(zip(grads_all, vars_all), global_step=global_step) + def _apply_weight_decay(self, grads, vars_): + """Update gradients to reflect weight decay.""" + # Don't decay bias + return [ + g + self.config.weight_decay * v if v.name.endswith("kernel:0") else g + for g, v in zip(grads, vars_) + ] - if report: - logits, _ = self.call(inputs, training=True) - loss = self.compute_loss(logits, labels) + def get_moving_stats(self): + vars_and_vals = {} - return loss + def _is_moving_var(v): + n = v.name + return n.endswith("moving_mean:0") or n.endswith("moving_variance:0") - def eval_step(self, inputs, labels): - """Evaluate.""" + for v in filter(_is_moving_var, self.variables): + vars_and_vals[v] = v.read_value() - logits, _ = self.call(inputs, training=False) - preds = tf.cast(tf.argmax(logits, axis=1), tf.int32) - corrects = tf.cast(tf.equal(preds, labels), tf.float32) - accuracy = tf.reduce_mean(corrects) + return vars_and_vals - return accuracy + def restore_moving_stats(self, vars_and_vals): + for var_, val in six.iteritems(vars_and_vals): + var_.assign(val) diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py index 68502ceac2360e2b9ea965743d507439a09c3e59..a5f240436a51f1f07669e06017761f003bfd9395 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py @@ -28,10 +28,19 @@ from tensorflow.python.client import device_lib tfe = tf.contrib.eager -class RevnetTest(tf.test.TestCase): +def train_one_iter(model, inputs, labels, optimizer, global_step=None): + """Train for one iteration.""" + grads, vars_, loss = model.compute_gradients(inputs, labels, training=True) + optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) + + return loss + + +class RevNetTest(tf.test.TestCase): def setUp(self): - super(RevnetTest, self).setUp() + super(RevNetTest, self).setUp() + tf.set_random_seed(1) config = config_.get_hparams_imagenet_56() shape = (config.batch_size,) + config.input_shape self.model = revnet.RevNet(config=config) @@ -48,7 +57,7 @@ class RevnetTest(tf.test.TestCase): del self.x del self.t del self.config - super(RevnetTest, self).tearDown() + super(RevNetTest, self).tearDown() def test_call(self): """Test `call` function.""" @@ -59,7 +68,8 @@ class RevnetTest(tf.test.TestCase): def test_compute_gradients(self): """Test `compute_gradients` function.""" - grads, vars_ = self.model.compute_gradients(inputs=self.x, labels=self.t) + grads, vars_, _ = self.model.compute_gradients( + inputs=self.x, labels=self.t, training=True) self.assertTrue(isinstance(grads, list)) self.assertTrue(isinstance(vars_, list)) self.assertEqual(len(grads), len(vars_)) @@ -67,38 +77,48 @@ class RevnetTest(tf.test.TestCase): if grad is not None: self.assertEqual(grad.shape, var.shape) - def test_train_step(self): - """Test `train_step` function.""" - - logits, _ = self.model(self.x, training=True) - loss = self.model.compute_loss(logits=logits, labels=self.t) - optimizer = tf.train.AdamOptimizer(learning_rate=1e-3) - - # Loss should be decreasing after each optimization step - for _ in range(3): - loss_ = self.model.train_step(self.x, self.t, optimizer, report=True) - self.assertTrue(loss_.numpy() <= loss.numpy()) - loss = loss_ - def test_call_defun(self): - """Test `call` function with tfe.defun apply.""" + """Test `call` function with defun.""" y, _ = tfe.defun(self.model.call)(self.x, training=False) self.assertEqual(y.shape, [self.config.batch_size, self.config.n_classes]) - def test_train_step_defun(self): - self.model.call = tfe.defun(self.model.call) - logits, _ = self.model(self.x, training=True) - loss = self.model.compute_loss(logits=logits, labels=self.t) - optimizer = tf.train.AdamOptimizer(learning_rate=1e-3) - - for _ in range(3): - loss_ = self.model.train_step(self.x, self.t, optimizer, report=True) - self.assertTrue(loss_.numpy() <= loss.numpy()) - loss = loss_ + def test_compute_gradients_defun(self): + """Test `compute_gradients` function with defun.""" + compute_gradients = tfe.defun(self.model.compute_gradients) + grads, vars_, _ = compute_gradients(self.x, self.t, training=True) + self.assertTrue(isinstance(grads, list)) + self.assertTrue(isinstance(vars_, list)) + self.assertEqual(len(grads), len(vars_)) + for grad, var in zip(grads, vars_): + if grad is not None: + self.assertEqual(grad.shape, var.shape) - # Initialize new model, so that other tests are not affected - self.model = revnet.RevNet(config=self.config) + def test_training_graph(self): + """Test model training in graph mode.""" + + with tf.Graph().as_default(): + x = tf.random_normal( + shape=(self.config.batch_size,) + self.config.input_shape) + t = tf.random_uniform( + shape=(self.config.batch_size,), + minval=0, + maxval=self.config.n_classes, + dtype=tf.int32) + global_step = tfe.Variable(0., trainable=False) + model = revnet.RevNet(config=self.config) + grads_all, vars_all, _ = model.compute_gradients(x, t, training=True) + optimizer = tf.train.AdamOptimizer(learning_rate=1e-3) + updates = model.get_updates_for(x) + self.assertEqual(len(updates), 192) + with tf.control_dependencies(model.get_updates_for(x)): + train_op = optimizer.apply_gradients( + zip(grads_all, vars_all), global_step=global_step) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + for _ in range(1): + sess.run(train_op) # Benchmark related @@ -126,7 +146,7 @@ class MockIterator(object): return self._tensors -class RevnetBenchmark(tf.test.Benchmark): +class RevNetBenchmark(tf.test.Benchmark): """Eager and graph benchmarks for RevNet.""" def _train_batch_sizes(self): @@ -227,7 +247,7 @@ class RevnetBenchmark(tf.test.Benchmark): iterator = make_iterator((images, labels)) for _ in range(num_burn): (images, labels) = iterator.next() - model.train_step(images, labels, optimizer) + train_one_iter(model, images, labels, optimizer) if execution_mode: tfe.async_wait() self._force_device_sync() @@ -236,7 +256,7 @@ class RevnetBenchmark(tf.test.Benchmark): start = time.time() for _ in range(num_iters): (images, labels) = iterator.next() - model.train_step(images, labels, optimizer) + train_one_iter(model, images, labels, optimizer) if execution_mode: tfe.async_wait() self._force_device_sync() diff --git a/tensorflow/contrib/eager/python/examples/sagan/BUILD b/tensorflow/contrib/eager/python/examples/sagan/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..b470a41d815ce650731680065cc7341f844e3fdc --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/sagan/BUILD @@ -0,0 +1,59 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +# Model +py_library( + name = "config", + srcs = ["config.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "ops", + srcs = ["ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "sagan", + srcs = ["sagan.py"], + srcs_version = "PY2AND3", + deps = [ + ":ops", + "//tensorflow:tensorflow_py", + ], +) + +# Tests +cuda_py_test( + name = "ops_test", + size = "small", + srcs = ["ops_test.py"], + additional_deps = [ + ":ops", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "sagan_test", + size = "large", + srcs = ["sagan_test.py"], + additional_deps = [ + ":config", + ":sagan", + "//tensorflow:tensorflow_py", + ], + tags = [ + "optonly", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/sagan/config.py b/tensorflow/contrib/eager/python/examples/sagan/config.py new file mode 100644 index 0000000000000000000000000000000000000000..1967bbd867447d9deaf9a7cb3b22a38889276a50 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/sagan/config.py @@ -0,0 +1,72 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Self-attention generative adversarial with eager execution. + +Configuration in format of tf.contrib.training.HParams. +Supports default 128x128 ImageNet. + +Reference [Self-Attention Generative Adversarial +Networks](https://arxiv.org/pdf/1805.08318.pdf) + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +tfe = tf.contrib.eager + + +def get_hparams_imagenet(): + """Configurations to train SAGAN on 128x128 ImageNet dataset.""" + config = tf.contrib.training.HParams() + if tf.test.is_gpu_available(): + config.add_hparam("image_shape", (3, 128, 128)) + config.add_hparam("data_format", "channels_first") + config.add_hparam("g_init_shape", (512, 4, 4)) + else: + config.add_hparam("image_shape", (128, 128, 3)) + config.add_hparam("data_format", "channels_first") + config.add_hparam("g_init_shape", (4, 4, 512)) + + config.add_hparam("latent_dim", 128) + config.add_hparam("update_g_once_every", 1) + config.add_hparam("batch_size", 64) + config.add_hparam("d_init_filters", 32) + config.add_hparam("num_upsamples", 5) + # (512, 4, 4) -> (3, 128, 128) + return config + + +def get_hparams_mock(): + """Configurations of smaller networks for testing.""" + config = tf.contrib.training.HParams() + if tf.test.is_gpu_available(): + config.add_hparam("image_shape", (3, 16, 16)) + config.add_hparam("data_format", "channels_first") + config.add_hparam("g_init_shape", (32, 2, 2)) + else: + config.add_hparam("image_shape", (16, 16, 3)) + config.add_hparam("data_format", "channels_last") + config.add_hparam("g_init_shape", (2, 2, 32)) + + config.add_hparam("latent_dim", 16) + config.add_hparam("update_g_once_every", 1) + config.add_hparam("batch_size", 2) + config.add_hparam("d_init_filters", 4) + config.add_hparam("num_upsamples", 3) + # (32, 2, 2) -> (3, 16, 16) + return config diff --git a/tensorflow/contrib/eager/python/examples/sagan/ops.py b/tensorflow/contrib/eager/python/examples/sagan/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9a03cab1d12fc16baa7343f72ac58ccd39f698bc --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/sagan/ops.py @@ -0,0 +1,71 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Self-attention generative adversarial with eager execution. + +Auxiliary operations. + +Reference [Self-Attention Generative Adversarial +Networks](https://arxiv.org/pdf/1805.08318.pdf) +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +def flatten_hw(x, data_format="channels_first"): + """Flatten the input tensor across height and width dimensions.""" + if data_format == "channels_last": + x = tf.transpose(x, perm=[0, 3, 1, 2]) # Convert to `channels_first` + + old_shape = tf.shape(x) + new_shape = [old_shape[0], old_shape[2] * old_shape[3], old_shape[1]] + + return tf.reshape(x, new_shape) + + +def broaden_hw(x, h, w, c, data_format="channels_first"): + """Broaden dimension so that output has height and width.""" + if data_format == "channels_first": + shape = [-1, c, h, w] + else: + shape = [-1, h, w, c] + + return tf.reshape(x, shape) + + +class BroadenHW(tf.keras.layers.Layer): + """Wrapper class so that `broaden_hw` can be used in `tf.keras.Sequential`.""" + + def __init__(self, h, w, c, data_format="channels_first"): + super(BroadenHW, self).__init__() + self.h = h + self.w = w + self.c = c + self.data_format = data_format + + def call(self, x): + return broaden_hw( + x, h=self.h, w=self.w, c=self.c, data_format=self.data_format) + + def compute_output_shape(self, input_shape): + input_shape = tf.TensorShape(input_shape).as_list() + if self.data_format == "channels_first": + output_shape = (input_shape[0], self.c, self.h, self.w) + else: + output_shape = (input_shape[0], self.h, self.w, self.c) + + return tf.TensorShape(output_shape) diff --git a/tensorflow/contrib/eager/python/examples/sagan/ops_test.py b/tensorflow/contrib/eager/python/examples/sagan/ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3454985904215b59d27fc4b76ccb4a8c2c2eff00 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/sagan/ops_test.py @@ -0,0 +1,59 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for auxiliary operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.sagan import ops + + +class OpsTest(tf.test.TestCase): + + def test_flatten_hw(self): + """Test `flatten_hw` function with mock object.""" + + batch_size = 1 + # Default NCHW format + if tf.test.is_gpu_available(): + x = tf.random_normal(shape=(batch_size, 3, 4, 4)) + y = ops.flatten_hw(x, data_format="channels_first") + self.assertEqual(y.shape, (batch_size, 4 * 4, 3)) + + # NHWC format + x = tf.random_normal(shape=(batch_size, 4, 4, 3)) + y = ops.flatten_hw(x, data_format="channels_last") + self.assertEqual(y.shape, (batch_size, 4 * 4, 3)) + + def test_broaden_hw(self): + """Test `broaden_hw` function with mock object.""" + + batch_size = 1 + # NHWC format + x = tf.random_normal(shape=[batch_size, 4 * 4 * 16]) + y = ops.broaden_hw(x, h=4, w=4, c=16, data_format="channels_last") + self.assertEqual(y.shape, (batch_size, 4, 4, 16)) + + # Default NCHW format + if tf.test.is_gpu_available(): + y = ops.broaden_hw(x, h=4, w=4, c=16, data_format="channels_first") + self.assertEqual(y.shape, (batch_size, 16, 4, 4)) + + +if __name__ == "__main__": + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan.py b/tensorflow/contrib/eager/python/examples/sagan/sagan.py new file mode 100644 index 0000000000000000000000000000000000000000..561be36c911d7145e2d4a5ed12eccd8ceb054f45 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/sagan/sagan.py @@ -0,0 +1,232 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Self-attention generative adversarial with eager execution. + +Code for main model. + +Reference [Self-Attention Generative Adversarial +Networks](https://arxiv.org/pdf/1805.08318.pdf) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.sagan import ops +tfe = tf.contrib.eager + + +class SelfAttentionModule(tf.keras.Model): + """Self-attention module composed of convolutional layers.""" + + def __init__(self, + attention_features, + original_features, + data_format="channels_first"): + """Initialize the module. + + Args: + attention_features: Number of filters for the attention computation. + original_features: Number of filters of the original Tensor. + data_format: Either 'channels_first' or 'channels_last' + """ + super(SelfAttentionModule, self).__init__() + self.data_format = data_format + # Matrix multiplication implemented as 2D Convolution + self.f = tf.keras.layers.Conv2D( + filters=attention_features, + kernel_size=1, + strides=(1, 1), + data_format=data_format) + self.g = tf.keras.layers.Conv2D( + filters=attention_features, + kernel_size=1, + strides=(1, 1), + data_format=data_format) + self.h = tf.keras.layers.Conv2D( + filters=original_features, + kernel_size=1, + strides=(1, 1), + data_format=data_format) + self.scale = tfe.Variable(0., trainable=True) + + def call(self, x): + f = self.f(x) + g = self.g(x) + h = self.h(x) + + f_flatten = ops.flatten_hw(f, data_format=self.data_format) + g_flatten = ops.flatten_hw(g, data_format=self.data_format) + h_flatten = ops.flatten_hw(h, data_format=self.data_format) + + s = tf.matmul(g_flatten, f_flatten, transpose_b=True) + b = tf.nn.softmax(s, axis=-1) + o = tf.matmul(b, h_flatten) + y = self.scale * tf.reshape(o, tf.shape(x)) + x + + return y + + def compute_output_shape(self, input_shape): + return input_shape + + +class SAGAN(tf.contrib.checkpoint.Checkpointable): + """Self-attention generative adversarial network.""" + + def __init__(self, config): + """Initialize the model. + + Args: + config: tf.contrib.training.HParams object; specifies hyperparameters + """ + super(SAGAN, self).__init__() + self.config = config + self.generator = self._construct_generator() + self.discriminator = self._construct_discriminator() + + def _construct_generator(self): + """Construct generator.""" + # TODO(lxuechen): Add spectral normalization for WGAN + axis = 1 if self.config.data_format == "channels_first" else 3 + + generator = tf.keras.Sequential() + generator.add( + tf.keras.layers.InputLayer(input_shape=(self.config.latent_dim,))) + generator.add( + tf.keras.layers.Dense( + units=np.prod(self.config.g_init_shape), activation=tf.nn.relu)) + + if self.config.data_format == "channels_first": + c, h, w = self.config.g_init_shape + else: + h, w, c = self.config.g_init_shape + + # Reshape to NHWC/NCHW + generator.add( + ops.BroadenHW(h=h, w=w, c=c, data_format=self.config.data_format)) + + filters_list = [c // 2**p for p in range(1, self.config.num_upsamples + 1)] + filters_list[-1] = 3 # Standard RGB images + + for filters in filters_list[:len(filters_list) // 2]: + generator.add( + tf.keras.layers.Conv2DTranspose( + filters=filters, + kernel_size=4, + strides=(2, 2), + use_bias=False, + padding="SAME", + data_format=self.config.data_format)) + generator.add(tf.keras.layers.BatchNormalization(axis=axis)) + generator.add(tf.keras.layers.Activation("relu")) + + # pylint: disable=undefined-loop-variable + generator.add( + SelfAttentionModule( + original_features=filters, + attention_features=filters // 8, + data_format=self.config.data_format)) + # pylint: enable=undefined-loop-variable + + for filters in filters_list[len(filters_list) // 2:]: + generator.add( + tf.keras.layers.Conv2DTranspose( + filters=filters, + kernel_size=4, + strides=(2, 2), + use_bias=False, + padding="SAME", + data_format=self.config.data_format)) + if filters == 3: + # Assume Image rescaled to [-1, 1] + generator.add(tf.keras.layers.Activation("tanh")) + else: + generator.add(tf.keras.layers.BatchNormalization(axis=axis)) + generator.add(tf.keras.layers.Activation("relu")) + + return generator + + def _construct_discriminator(self): + """Construct discriminator.""" + # TODO(lxuechen): Add spectral normalization for WGAN + discriminator = tf.keras.Sequential() + discriminator.add( + tf.keras.layers.InputLayer(input_shape=self.config.image_shape)) + + filters_list = [ + self.config.d_init_filters * 2**p + for p in range(self.config.num_upsamples) + ] + + for filters in filters_list[:(len(filters_list) + 1) // 2]: + discriminator.add( + tf.keras.layers.Conv2D( + filters=filters, + kernel_size=4, + strides=(2, 2), + padding="SAME", + data_format=self.config.data_format)) + discriminator.add(tf.keras.layers.LeakyReLU(alpha=.1)) + + # pylint: disable=undefined-loop-variable + discriminator.add( + SelfAttentionModule( + original_features=filters, + attention_features=filters // 8, + data_format=self.config.data_format)) + # pylint: enable=undefined-loop-variable + + for filters in filters_list[(len(filters_list) + 1) // 2:]: + discriminator.add( + tf.keras.layers.Conv2D( + filters=filters, + kernel_size=4, + strides=(2, 2), + padding="SAME", + data_format=self.config.data_format)) + discriminator.add(tf.keras.layers.LeakyReLU(alpha=.1)) + + discriminator.add(tf.keras.layers.Flatten()) + discriminator.add(tf.keras.layers.Dense(units=1)) + + return discriminator + + def compute_loss_and_grads(self, real_images, noise, training=True): + """Compute loss and gradients for both generator and discriminator.""" + # TODO(lxuechen): Add gradient penalty for discriminator + with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape: + real_logits = self.discriminator(real_images, training=training) + + fake_images = self.generator.call(noise, training=training) + fake_logits = self.discriminator.call(fake_images) + + g_loss = self.compute_g_loss(fake_logits) + d_loss = self.compute_d_loss(fake_logits, real_logits) + + g_grads = g_tape.gradient(g_loss, self.generator.trainable_variables) + d_grads = d_tape.gradient(d_loss, self.discriminator.trainable_variables) + + return g_loss, d_loss, g_grads, d_grads + + def compute_g_loss(self, fake_logits): + return -tf.reduce_mean(fake_logits) # Hinge loss + + def compute_d_loss(self, fake_logits, real_logits): + # Hinge loss + real_loss = tf.reduce_mean(tf.nn.relu(1. - real_logits)) + fake_loss = tf.reduce_mean(tf.nn.relu(1. + fake_logits)) + return real_loss + fake_loss diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py b/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py new file mode 100644 index 0000000000000000000000000000000000000000..18345945108111b57c5401c26b7dca0bfc8f8316 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py @@ -0,0 +1,101 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for self-attention generative adversarial network.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.sagan import config as config_ +from tensorflow.contrib.eager.python.examples.sagan import sagan +tfe = tf.contrib.eager + + +class SAGANTest(tf.test.TestCase): + + def setUp(self): + super(SAGANTest, self).setUp() + config = config_.get_hparams_mock() + self.noise_shape = (config.batch_size, config.latent_dim) + self.logits_shape = (config.batch_size, 1) + self.images_shape = (config.batch_size,) + config.image_shape + + self.model = sagan.SAGAN(config=config) + self.noise = tf.random_normal(shape=self.noise_shape) + self.real_images = tf.random_normal(shape=self.images_shape) + self.config = config + + def tearDown(self): + del self.model + del self.noise + del self.real_images + super(SAGANTest, self).tearDown() + + def test_generator_call(self): + """Test `generator.__call__` function.""" + fake_images = self.model.generator(self.noise, training=False) + self.assertEqual(fake_images.shape, self.images_shape) + + def test_generator_call_defun(self): + """Test `generator.__call__` function with defun.""" + call_ = tfe.defun(self.model.generator.__call__) + fake_images = call_(self.noise, training=False) + self.assertEqual(fake_images.shape, self.images_shape) + + def test_discriminator_call(self): + """Test `discriminator.__call__` function.""" + real_logits = self.model.discriminator(self.real_images) + self.assertEqual(real_logits.shape, self.logits_shape) + + def test_discriminator_call_defun(self): + """Test `discriminator.__call__` function with defun.""" + call_ = tfe.defun(self.model.discriminator.__call__) + real_logits = call_(self.real_images) + self.assertEqual(real_logits.shape, self.logits_shape) + + def test_compute_loss_and_grads(self): + """Test `compute_loss_and_grads` function.""" + g_loss, d_loss, g_grads, d_grads = self.model.compute_loss_and_grads( + self.real_images, self.noise, training=False) + self.assertEqual(g_loss.shape, ()) + self.assertEqual(d_loss.shape, ()) + self.assertTrue(isinstance(g_grads, list)) + self.assertTrue(isinstance(d_grads, list)) + g_vars = self.model.generator.trainable_variables + d_vars = self.model.discriminator.trainable_variables + + self.assertEqual(len(g_grads), len(g_vars)) + self.assertEqual(len(d_grads), len(d_vars)) + + def test_compute_loss_and_grads_defun(self): + """Test `compute_loss_and_grads` function with defun.""" + compute_loss_and_grads = tfe.defun(self.model.compute_loss_and_grads) + g_loss, d_loss, g_grads, d_grads = compute_loss_and_grads( + self.real_images, self.noise, training=False) + self.assertEqual(g_loss.shape, ()) + self.assertEqual(d_loss.shape, ()) + self.assertTrue(isinstance(g_grads, list)) + self.assertTrue(isinstance(d_grads, list)) + g_vars = self.model.generator.trainable_variables + d_vars = self.model.discriminator.trainable_variables + + self.assertEqual(len(g_grads), len(g_vars)) + self.assertEqual(len(d_grads), len(d_vars)) + + +if __name__ == "__main__": + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/g3doc/guide.md b/tensorflow/contrib/eager/python/g3doc/guide.md index 2d2aba6908b168e0bf63f4706b6344cbb4ca82bd..23f33d0230b0b9fa906636a9df4e046c6873d90b 100644 --- a/tensorflow/contrib/eager/python/g3doc/guide.md +++ b/tensorflow/contrib/eager/python/g3doc/guide.md @@ -4,8 +4,8 @@ Eager execution is a feature that makes TensorFlow execute operations immediately: concrete values are returned, instead of creating a computational graph that is executed later. -A user guide is available: https://www.tensorflow.org/programmers_guide/eager -([source file](../../../../docs_src/programmers_guide/eager.md)) +A user guide is available: https://www.tensorflow.org/guide/eager +([source file](../../../../docs_src/guide/eager.md)) We welcome feedback through [GitHub issues](https://github.com/tensorflow/tensorflow/labels/comp:eager). diff --git a/tensorflow/contrib/eager/python/metrics.py b/tensorflow/contrib/eager/python/metrics.py index 3e3100427376ddd480b50d967cf53e7831aaefb2..04b7b1165e19612be2fa878f83effbe814fc5c46 100644 --- a/tensorflow/contrib/eager/python/metrics.py +++ b/tensorflow/contrib/eager/python/metrics.py @@ -22,5 +22,6 @@ from __future__ import print_function from tensorflow.contrib.eager.python.metrics_impl import * from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['Accuracy', 'Mean', 'Metric'] +_allowed_symbols = ['Accuracy', 'Mean', 'Metric', 'CategoricalAccuracy', + 'BinaryAccuracy', 'SparseAccuracy'] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index c947ed9dcc415670a820f8a5cd9eaaf07334cfc3..efa6ba062631500bd7cd16620ebec23d15b93b62 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -345,9 +345,14 @@ class Mean(Metric): class Accuracy(Mean): - """Calculates how often `predictions` matches `labels`.""" + """Calculates how often `predictions` matches `labels`. + Attributes: + name: name of the accuracy object + dtype: data type of the tensor + """ def __init__(self, name=None, dtype=dtypes.float64): + """Inits Accuracy class with name and dtype.""" super(Accuracy, self).__init__(name=name, dtype=dtype) def call(self, labels, predictions, weights=None): @@ -377,3 +382,146 @@ class Accuracy(Mean): if weights is None: return labels, predictions return labels, predictions, weights + + +class CategoricalAccuracy(Mean): + """Calculates how often `predictions` matches `labels`. + + This class is compatible with `tf.keras.losses.categorical_crossentropy`, + `tf.nn.softmax_cross_entropy_with_logits_v2`, + `tf.losses.softmax_cross_entropy`. + + Attributes: + name: name of the accuracy object. + dtype: data type of tensor. + """ + + def __init__(self, name=None, dtype=dtypes.float64): + """Inits CategoricalAccuracy with name and dtype.""" + super(CategoricalAccuracy, self).__init__(name=name, dtype=dtype) + + def call(self, labels, predictions, weights=None): + """Accumulate accuracy statistics. + + `labels` and `predictions` should have the same shape. + As argmax is being done here, labels and predictions type + can be different. + + Args: + labels: One-hot Tensor. + predictions: Tensor with the logits or probabilities for each example. + weights: Optional weighting of each example. Defaults to 1. + + Returns: + The arguments, for easy chaining. + """ + check_ops.assert_equal( + array_ops.shape(labels), array_ops.shape(predictions), + message="Shapes of labels and predictions are unequal") + labels = math_ops.argmax(labels, axis=-1) + predictions = math_ops.argmax(predictions, axis=-1) + matches = math_ops.equal(labels, predictions) + matches = math_ops.cast(matches, dtypes.float64) + super(CategoricalAccuracy, self).call(matches, weights=weights) + if weights is None: + return labels, predictions + return labels, predictions, weights + + +class BinaryAccuracy(Mean): + """Calculates how often `predictions` matches `labels`. + + This class is compatible with `tf.keras.losses.binary_crossentropy`, + `tf.losses.sigmoid_cross_entropy`, + `tf.nn.sigmoid_cross_entropy_with_logits`. + If there is more than one label, this will become multi-label classification. + + Attributes: + name: name of the accuracy object. + threshold: Used for rounding off the predictions. + If the predictions are, + 1. probabilities then set the threshold to 0.5. + 2. logits then set the threshold to 0. + You can set the threshold appropriately, + to trade off with precision and recall. + dtype: data type of tensor. + """ + + def __init__(self, threshold, name=None, dtype=dtypes.float64): + """Inits BinaryAccuracy with name, threshold and dtype.""" + + super(BinaryAccuracy, self).__init__(name=name, dtype=dtype) + self.threshold = threshold + + def call(self, labels, predictions, weights=None): + """Accumulate accuracy statistics. + + `labels` and `predictions` should have the same shape and type. + + Args: + labels: Binary Tensor(containing 0 or 1). + predictions: Tensor with probabilities or logits. + weights: Optional weighting of each example. Defaults to 1. + + Returns: + The arguments, for easy chaining. + """ + check_ops.assert_equal( + array_ops.shape(labels), array_ops.shape(predictions), + message="Shapes of labels and predictions are unequal") + predictions = ops.convert_to_tensor(predictions) + predictions = predictions > self.threshold + matches = math_ops.equal(labels, predictions) + matches = math_ops.cast(matches, dtypes.float64) + super(BinaryAccuracy, self).call(matches, weights=weights) + if weights is None: + return labels, predictions + return labels, predictions, weights + + +class SparseAccuracy(Mean): + """Calculates how often `predictions` matches `labels`. + + This class is compatible with + `tf.keras.losses.sparse_categorical_crossentropy`, + `tf.nn.sparse_softmax_cross_entropy_with_logits`, + `tf.losses.sparse_softmax_cross_entropy`. + + Attributes: + name: name of the accuracy object + dtype: data type of tensor. + """ + + def __init__(self, name=None, dtype=dtypes.float64): + """Inits SparseAccuracy with name and dtype.""" + + super(SparseAccuracy, self).__init__(name=name, dtype=dtype) + + def call(self, labels, predictions, weights=None): + """Accumulate accuracy statistics. + + `labels` and `predictions` should have the same shape except the + predictions must have one additional trailing dimension equal to the + number of classes(you want to predict). + + Type of labels and predictions can be different. + + Args: + labels: Tensor of shape (batch_size, ) containing integers + predictions: Tensor with the logits or probabilities for each example. + weights: Optional weighting of each example. Defaults to 1. + + Returns: + The arguments, for easy chaining. + """ + check_ops.assert_equal( + array_ops.shape(labels), array_ops.shape(predictions)[0], + message="First axis of labels and predictions is unequal") + predictions = math_ops.argmax(predictions, axis=-1) + labels = math_ops.cast(labels, dtypes.int64) + matches = math_ops.equal(labels, predictions) + matches = math_ops.cast(matches, dtypes.float64) + super(SparseAccuracy, self).call(matches, weights=weights) + if weights is None: + return labels, predictions + return labels, predictions, weights diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index 02ee05487515b81bfae70d02c1dfdb6d816b77c7..20d938d492bf78fab852c638ba675d7ee6ed9073 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -118,6 +118,39 @@ class MetricsTest(test.TestCase): self.assertEqual(dtypes.float64, m.dtype) self.assertEqual(dtypes.float64, m.result().dtype) + def testCategoricalAccuracy(self): + m = metrics.CategoricalAccuracy() + m([[1, 0, 0, 0], [0, 1, 0, 0]], + [[0.6, 0.1, 0.25, 0.05], [0.4, 0.05, 0.45, 0.0]]) # 1/2 correct + m([[0, 0, 0, 1]], [[0.25, 0.95, 0.25, 0.0]]) # 0/1 correct + m([[1, 0, 0, 0], [0, 1, 0, 0]], + [[0.99, 0.01, 0.0, 0.0], [0.35, 0.35, 0.3, 0.0]]) # 1/2 correct + self.assertEqual(2.0/5, m.result().numpy()) + self.assertEqual(dtypes.float64, m.dtype) + self.assertEqual(dtypes.float64, m.result().dtype) + + def testBinaryAccuracy(self): + m = metrics.BinaryAccuracy(threshold=0) + # as threshold is 0 hence the predictions are logits + m([[0, 0, 0, 0]], + [[-4.2, 4.5, 1.2, -1.1]]) # 2/4 correct + m([[0, 1]], [[-5.3, 11.65]]) # 2/2 correct + m([[0, 1], [1, 1]], + [[-5.3, 11.65], [-10.32, 56.38]]) # 3/4 correct + self.assertEqual(7.0/10, m.result().numpy()) + self.assertEqual(dtypes.float64, m.dtype) + self.assertEqual(dtypes.float64, m.result().dtype) + + def testSparseAccuracy(self): + m = metrics.SparseAccuracy() + m([0, 2], + [[0.6, 0.1, 0.25, 0.05], [0.4, 0.05, 0.45, 0.0]]) # 2/2 correct + m([1], [[0.25, 0.95, 0.25, 0.0]]) # 1/1 correct + m([0, 3], [[0.99, 0.01, 0.0, 0.0], [0.35, 0.35, 0.3, 0.0]]) # 1/2 correct + self.assertEqual(4.0/5, m.result().numpy()) + self.assertEqual(dtypes.float64, m.dtype) + self.assertEqual(dtypes.float64, m.result().dtype) + def testAccuracyDifferentShapes(self): m = metrics.Accuracy() with self.assertRaises(errors.InvalidArgumentError): @@ -173,7 +206,7 @@ class MetricsTest(test.TestCase): sess.run(accumulate, feed_dict={p: 7}) self.assertAllEqual(m.result().eval(), 7) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGraphAndEagerTensor(self): m = metrics.Mean() inputs = ops.convert_to_tensor([1.0, 2.0]) @@ -221,7 +254,7 @@ class MetricsTest(test.TestCase): self.assertAllEqual(m2.result().eval(), 2.0) self.assertAllEqual(m1.result().eval(), 1.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSaveRestore(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index c92bd15b253b67a3301cd562046a4467e1bf877d..240f213c602395b8589d39c3ecd90b602ffa9848 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -126,7 +126,7 @@ class NetworkTest(test.TestCase): self.assertAllEqual([[17.0], [34.0]], self.evaluate(result)) # TODO(allenl): This test creates garbage in some Python versions - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNetworkSaveRestoreAlreadyBuilt(self): net = MyNetwork(name="abcd") with self.assertRaisesRegexp( @@ -138,7 +138,7 @@ class NetworkTest(test.TestCase): self._save_modify_load_network_built(net, global_step=10) # TODO(allenl): This test creates garbage in some Python versions - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSaveRestoreDefaultGlobalStep(self): net = MyNetwork(name="abcd") net(constant_op.constant([[2.0]])) @@ -149,7 +149,7 @@ class NetworkTest(test.TestCase): self.assertIn("abcd-4242", save_path) # TODO(allenl): This test creates garbage in some Python versions - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNetworkSaveAndRestoreIntoUnbuilt(self): save_dir = self.get_temp_dir() net1 = MyNetwork() @@ -166,7 +166,7 @@ class NetworkTest(test.TestCase): self.assertAllEqual(self.evaluate(net1.variables[0]), self.evaluate(net2.variables[0])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNetworkMatchesLayerVariableNames(self): zero = constant_op.constant([[0.]]) layer_one = core.Dense(1, use_bias=False) @@ -193,7 +193,7 @@ class NetworkTest(test.TestCase): self.assertEqual("two_layer_net/" + layer_two.variables[0].name, net.second.variables[0].name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLoadIntoUnbuiltSharedLayer(self): class Owner(network.Network): @@ -272,7 +272,7 @@ class NetworkTest(test.TestCase): network.restore_network_checkpoint( load_into, save_path, map_func=_restore_map_func) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRestoreIntoSubNetwork(self): class Parent(network.Network): @@ -327,7 +327,7 @@ class NetworkTest(test.TestCase): # The checkpoint is incompatible. network.restore_network_checkpoint(save_into_parent, checkpoint) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCustomMapCollisionErrors(self): class Parent(network.Network): @@ -372,7 +372,7 @@ class NetworkTest(test.TestCase): network.restore_network_checkpoint( loader, checkpoint, map_func=lambda n: "foo") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDefaultMapCollisionErrors(self): one = constant_op.constant([[1.]]) @@ -571,7 +571,7 @@ class NetworkTest(test.TestCase): expected_start="my_network_1/dense/", actual=outside_net_after.trainable_weights[0].name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVariableScopeStripping(self): with variable_scope.variable_scope("scope1"): with variable_scope.variable_scope("scope2"): @@ -596,7 +596,7 @@ class NetworkTest(test.TestCase): self.assertAllEqual([[42.]], self.evaluate(restore_net.variables[0])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLayerNamesRespected(self): class ParentNetwork(network.Network): @@ -677,7 +677,7 @@ class NetworkTest(test.TestCase): self.assertStartsWith(expected_start="my_network_1/dense/", actual=net2.trainable_weights[0].name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNestableAnonymous(self): # The case where no explicit names are specified. We make up unique names, @@ -721,7 +721,7 @@ class NetworkTest(test.TestCase): self.assertEqual("my_network", net2.first.name) self.assertEqual("my_network_1", net2.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNestableExplicit(self): # We have explicit network names and everything is globally unique. @@ -750,7 +750,7 @@ class NetworkTest(test.TestCase): self.assertEqual("first_unique_child_name", net.first.name) self.assertEqual("second_unique_child_name", net.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLayerNetworkNameInteractions(self): # Same base name as core.Dense; Networks and non-Network Layers with the @@ -801,7 +801,7 @@ class NetworkTest(test.TestCase): actual=net.trainable_weights[4].name) self.assertEqual("mixed_layer_network", net.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNestableExplicitCollisions(self): # We have explicit network names and they are unique within the layer @@ -831,7 +831,7 @@ class NetworkTest(test.TestCase): self.assertEqual("nonunique_name", net.first.name) self.assertEqual("second_unique_child_name", net.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNestableExplicitWithAnonymousParent(self): # A parent network is instantiated multiple times with explicitly named @@ -873,7 +873,7 @@ class NetworkTest(test.TestCase): self.assertEqual("first_unique_child_name", net2.first.name) self.assertEqual("second_unique_child_name", net2.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNestableExplicitSameLayerCollisions(self): # We have explicit network names and they are _not_ unique within the layer @@ -891,7 +891,7 @@ class NetworkTest(test.TestCase): with self.assertRaisesRegexp(ValueError, "nonunique_name"): ParentNetwork() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAnonymousVariableSharing(self): # Two "owned" Networks @@ -989,7 +989,7 @@ class NetworkTest(test.TestCase): self.assertEqual("my_network", net4.first.name) self.assertEqual("my_network", net4.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRecursiveLayerRenaming(self): core.Dense(1) # Under default Layer naming, would change subsequent names. @@ -1041,7 +1041,7 @@ class NetworkTest(test.TestCase): self.assertEqual("dense", net.second.first.name) self.assertEqual("dense_1", net.second.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCallInDifferentOrderThanConstruct(self): shared_network = MyNetwork() @@ -1091,7 +1091,7 @@ class NetworkTest(test.TestCase): self.assertTrue(net2.first is net1.first) self.assertEqual("my_network", net2.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLayerCallInDifferentOrderThanConstruct(self): # Same idea as testCallInDifferentOrderThanConstruct, but this time with a # non-Network Layer shared between two Networks rather than a @@ -1144,7 +1144,7 @@ class NetworkTest(test.TestCase): self.assertTrue(net2.first is net1.first) self.assertEqual("dense", net2.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLayerAlreadyBuilt(self): one = constant_op.constant([[1.]]) core.Dense(1, use_bias=False) # pre-built layers use global naming diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index fee9db46fa4f79d7dd613436726e8ddad51faf1c..ca6430253b67d825290b6a376ba3f29b3ae67577 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -68,6 +68,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@async_clear_error @@run_test_in_graph_and_eager_modes +@@run_all_tests_in_graph_and_eager_modes @@DEVICE_PLACEMENT_EXPLICIT @@DEVICE_PLACEMENT_WARN @@ -121,7 +122,7 @@ from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Vari from tensorflow.python.ops.variable_scope import EagerVariableStore from tensorflow.python.ops import script_ops from tensorflow.python.ops import template -from tensorflow.python.training.checkpointable.base import Checkpointable +from tensorflow.python.training.checkpointable.tracking import Checkpointable from tensorflow.python.training.checkpointable.util import CheckpointableSaver from tensorflow.python.training.checkpointable.util import Checkpoint from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/contrib/estimator/python/estimator/dnn.py index f1c60a912c8b1daa7db34f46e92bcc36ab300716..4bb90cf81bc32723e24a220e45c43c1f9b3f1980 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn.py @@ -53,6 +53,18 @@ class DNNEstimator(estimator.Estimator): l1_regularization_strength=0.001 )) + # Or estimator using an optimizer with a learning rate decay. + estimator = DNNEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3), + feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb], + hidden_units=[1024, 512, 256], + optimizer=lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + # Or estimator with warm-starting from a previous checkpoint. estimator = DNNEstimator( head=tf.contrib.estimator.multi_label_head(n_classes=3), @@ -115,8 +127,9 @@ class DNNEstimator(estimator.Estimator): model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to Adagrad optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to Adagrad optimizer. activation_fn: Activation function applied to each layer. If `None`, will use `tf.nn.relu`. dropout: When not `None`, the probability we will drop out a given diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py index ccaf1128bf23af734f7a5722a4dd8c1f0304fab7..894a2954987a4af760d3c08fc6f30405010150c5 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py @@ -53,12 +53,19 @@ class DNNLinearCombinedEstimator(estimator.Estimator): dnn_hidden_units=[1000, 500, 100], dnn_optimizer=tf.train.ProximalAdagradOptimizer(...)) - # To apply L1 and L2 regularization, you can set optimizers as follows: + # To apply L1 and L2 regularization, you can set dnn_optimizer to: tf.train.ProximalAdagradOptimizer( learning_rate=0.1, l1_regularization_strength=0.001, l2_regularization_strength=0.001) - # It is same for FtrlOptimizer. + # To apply learning rate decay, you can set dnn_optimizer to a callable: + lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96) + # It is the same for linear_optimizer. # Input builders def input_fn_train: # returns x, y @@ -116,12 +123,16 @@ class DNNLinearCombinedEstimator(estimator.Estimator): used by linear part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the linear part of the model. Defaults to FTRL optimizer. + the linear part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to FTRL + optimizer. dnn_feature_columns: An iterable containing all the feature columns used by deep part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the deep part of the model. Defaults to Adagrad optimizer. + the deep part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to Adagrad + optimizer. dnn_hidden_units: List of hidden units per layer. All layers are fully connected. dnn_activation_fn: Activation function applied to each layer. If None, diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index 9594e5132fd20dadea118fd1dd6768feb7fd7fff..c9d86ef4ab89950b0c7b0414ba60d9e0a1cbe476 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -534,7 +534,8 @@ def multi_label_head(n_classes, * An integer `SparseTensor` of class indices. The `dense_shape` must be `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`. * If `label_vocabulary` is given, a string `SparseTensor`. The `dense_shape` - must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary`. + must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary` or a + multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`. If `weight_column` is specified, weights must be of shape `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`. diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index b2b57fa06ba818d4455871fe57dde5ce287b39a2..7b884402d4650636bc9fe053994246aabb9c312d 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -568,6 +568,33 @@ class MultiLabelHead(test.TestCase): expected_loss=expected_loss, expected_metrics=expected_metrics) + def test_eval_with_label_vocabulary_with_multi_hot_input(self): + n_classes = 2 + head = head_lib.multi_label_head( + n_classes, label_vocabulary=['class0', 'class1']) + logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) + labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64) + # loss = labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits)) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) + keys = metric_keys.MetricKeys + expected_metrics = { + # Average loss over examples. + keys.LOSS_MEAN: expected_loss, + # auc and auc_pr cannot be reliably calculated for only 4 samples, but + # this assert tests that the algorithm remains consistent. + keys.AUC: 0.3333, + keys.AUC_PR: 0.7639, + } + self._test_eval( + head=head, + logits=logits, + labels=labels_multi_hot, + expected_loss=expected_loss, + expected_metrics=expected_metrics) + def test_eval_with_thresholds(self): n_classes = 2 thresholds = [0.25, 0.5, 0.75] diff --git a/tensorflow/contrib/estimator/python/estimator/linear.py b/tensorflow/contrib/estimator/python/estimator/linear.py index 3bf4abe83d54504d55de73b63f369cceaf149dd2..b960b16f1ba6b1bf8046c922e21ac1ed136c599e 100644 --- a/tensorflow/contrib/estimator/python/estimator/linear.py +++ b/tensorflow/contrib/estimator/python/estimator/linear.py @@ -39,6 +39,18 @@ class LinearEstimator(estimator.Estimator): feature_columns=[categorical_column_a, categorical_feature_a_x_categorical_feature_b]) + # Or estimator using an optimizer with a learning rate decay. + estimator = LinearEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3), + feature_columns=[categorical_column_a, + categorical_feature_a_x_categorical_feature_b], + optimizer=lambda: tf.train.FtrlOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + # Or estimator using the FTRL optimizer with regularization. estimator = LinearEstimator( head=tf.contrib.estimator.multi_label_head(n_classes=3), @@ -99,8 +111,9 @@ class LinearEstimator(estimator.Estimator): model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to FTRL optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to FTRL optimizer. config: `RunConfig` object to configure the runtime settings. partitioner: Optional. Partitioner for input layer. """ diff --git a/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc b/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc index bb9b835889b1b5e36d6f470b51834d4c6bb3d493..7fcae5ad8e1536530e2d039e1d14df4e192c4fa3 100644 --- a/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc +++ b/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc @@ -62,10 +62,11 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { public: explicit WALSComputePartialLhsAndRhsOp(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, context->MatchSignature( - {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, - DT_INT64, DT_FLOAT, DT_INT64, DT_BOOL}, - {DT_FLOAT, DT_FLOAT})); + OP_REQUIRES_OK(context, + context->MatchSignature( + {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64, + DT_FLOAT, DT_FLOAT, DT_INT64, DT_BOOL}, + {DT_FLOAT, DT_FLOAT})); } void Compute(OpKernelContext* context) override { @@ -75,8 +76,9 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { const Tensor& input_weights = context->input(3); const Tensor& input_indices = context->input(4); const Tensor& input_values = context->input(5); - const Tensor& input_block_size = context->input(6); - const Tensor& input_is_transpose = context->input(7); + const Tensor& entry_weights = context->input(6); + const Tensor& input_block_size = context->input(7); + const Tensor& input_is_transpose = context->input(8); OP_REQUIRES(context, TensorShapeUtils::IsMatrix(factors.shape()), InvalidArgument("Input factors should be a matrix.")); @@ -89,13 +91,33 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { InvalidArgument("Input input_weights should be a vector.")); OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices.shape()), InvalidArgument("Input input_indices should be a matrix.")); + OP_REQUIRES( + context, input_indices.dim_size(1) == 2, + InvalidArgument("Input input_indices should have shape (?, 2).")); OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values.shape()), InvalidArgument("Input input_values should be a vector")); + OP_REQUIRES(context, TensorShapeUtils::IsVector(entry_weights.shape()), + InvalidArgument("Input entry_weights should be a vector")); + OP_REQUIRES(context, input_indices.dim_size(0) == input_values.dim_size(0), + InvalidArgument("Input input_values' length should match the " + "first dimension of Input input_indices ")); OP_REQUIRES(context, TensorShapeUtils::IsScalar(input_block_size.shape()), InvalidArgument("Input input_block_size should be a scalar.")); OP_REQUIRES( context, TensorShapeUtils::IsScalar(input_is_transpose.shape()), InvalidArgument("Input input_is_transpose should be a scalar.")); + OP_REQUIRES( + context, + ((input_weights.dim_size(0) > 0 && + factor_weights.dim_size(0) == factors.dim_size(0) && + entry_weights.dim_size(0) == 0) || + (input_weights.dim_size(0) == 0 && factor_weights.dim_size(0) == 0 && + entry_weights.dim_size(0) == input_indices.dim_size(0))), + InvalidArgument("To specify the weights for observed entries, either " + "(1) entry_weights must be set or (2) input_weights " + "and factor_weights must be set, but not both.")); + // TODO(yifanchen): Deprecate the support of input_weights and + // factor_weights. const int64 factor_dim = factors.dim_size(1); const int64 factors_size = factors.dim_size(0); @@ -105,6 +127,7 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { const auto& input_weights_vec = input_weights.vec(); const float w_0 = unobserved_weights.scalar()(); const auto& input_values_vec = input_values.vec(); + const auto& entry_weights_vec = entry_weights.vec(); ConstEigenMatrixFloatMap factors_mat(factors.matrix().data(), factor_dim, factors_size); @@ -134,6 +157,8 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { return is_transpose ? indices_mat(0, i) : indices_mat(1, i); }; + const bool use_entry_weights = entry_weights_vec.size() > 0; + // TODO(rmlarsen): In principle, we should be using the SparseTensor class // and machinery for iterating over groups, but the fact that class // SparseTensor makes a complete copy of the matrix makes me reluctant to @@ -195,6 +220,8 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { // map using the hash of the thread id as the key. // // TODO(jpoulson): Switch to try_emplace once C++17 is supported + // TODO(b/72952120): Check whether the 3 lock-unlock pairs can be + // consolidated into just one. map_mutex.lock(); const auto key_count = factor_batch_map.count(id_hash); map_mutex.unlock(); @@ -213,6 +240,8 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { CHECK_LE(shard.second, perm.size()); CHECK_LE(shard.first, shard.second); const int64 input_index = get_input_index(perm[shard.first]); + const float input_weight = + use_entry_weights ? 1.0 : input_weights_vec(input_index); // Accumulate the rhs and lhs terms in the normal equations // for the non-zero elements in the row or column of the sparse matrix // corresponding to input_index. @@ -228,7 +257,8 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { const int64 factor_index = get_factor_index(i); const float input_value = input_values_vec(i); const float weight = - input_weights_vec(input_index) * factor_weights_vec(factor_index); + use_entry_weights ? entry_weights_vec(i) + : input_weight * factor_weights_vec(factor_index); CHECK_GE(weight, 0); factor_batch.col(num_batched) = factors_mat.col(factor_index) * std::sqrt(weight); diff --git a/tensorflow/contrib/factorization/ops/factorization_ops.cc b/tensorflow/contrib/factorization/ops/factorization_ops.cc index 11ea36946e92769cd6901eb998a20148250ef7ce..1d31bd38c824f24e9a70c0f69da129f5ddc18985 100644 --- a/tensorflow/contrib/factorization/ops/factorization_ops.cc +++ b/tensorflow/contrib/factorization/ops/factorization_ops.cc @@ -25,20 +25,33 @@ REGISTER_OP("WALSComputePartialLhsAndRhs") .Input("input_weights: float32") .Input("input_indices: int64") .Input("input_values: float32") + .Input("entry_weights: float32") .Input("input_block_size: int64") .Input("input_is_transpose: bool") .Output("partial_lhs: float32") .Output("partial_rhs: float32") .SetShapeFn(shape_inference::UnknownShape) .Doc(R"( -Computes the partial left-hand side and right-hand side of WALS update. +Computes the partial left-hand side and right-hand side of WALS update. For +observed entry input_indices[i]=[m, n] with value input_values[i]=v, the weight +should be specified either through (1) entry_weights[i] or (2) through +input_weights[m] * factor_weights[n] (if input_is_transpose is false) or +input_weights[n] * factor_weights[m] (if input_is_transpose is true). Note it is +not allowed to have both (1) and (2) specified at the same time: when one +approach is used, the input tensors related to the other approach must be kept +completely empty. factors: Matrix of size m * k. -factor_weights: Vector of size m. Corresponds to column weights +factor_weights: Vector of size m. Corresponds to column weights. Should be empty + if entry_weights is used. unobserved_weights: Scalar. Weight for unobserved input entries. -input_weights: Vector of size n. Corresponds to row weights. +input_weights: Vector of size n. Corresponds to row weights. Should be empty if + entry_weights is used. input_indices: Indices for the input SparseTensor. input_values: Values for the input SparseTensor. +entry_weights: If not empty, this must be same length as input_vaues and is used + as the per-entry non-zero weight. If this is used, input_weights and + factor_weights must be empty. input_block_size: Scalar. Number of rows spanned by input. input_is_transpose: If true, logically transposes the input for processing. partial_lhs: 3-D tensor with size input_block_size x k x k. diff --git a/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py b/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py index ba30fd997700f461b6afffa13cf371c598d3332e..6c2f1d46084d701beac1e3a99e3ad66bae57eda5 100644 --- a/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py +++ b/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py @@ -55,7 +55,41 @@ class WalsSolverOpsTest(test.TestCase): rhs_matrix] = gen_factorization_ops.wals_compute_partial_lhs_and_rhs( self._column_factors, self._column_weights, self._unobserved_weights, self._row_weights, sparse_block.indices, sparse_block.values, - sparse_block.dense_shape[0], False) + [], + input_block_size=sparse_block.dense_shape[0], + input_is_transpose=False) + self.assertAllClose(lhs_tensor.eval(), [[ + [0.014800, 0.017000, 0.019200], + [0.017000, 0.019600, 0.022200], + [0.019200, 0.022200, 0.025200], + ], [ + [0.0064000, 0.0080000, 0.0096000], + [0.0080000, 0.0100000, 0.0120000], + [0.0096000, 0.0120000, 0.0144000], + ], [ + [0.0099000, 0.0126000, 0.0153000], + [0.0126000, 0.0162000, 0.0198000], + [0.0153000, 0.0198000, 0.0243000], + ], [ + [0.058800, 0.067200, 0.075600], + [0.067200, 0.076800, 0.086400], + [0.075600, 0.086400, 0.097200], + ]]) + self.assertAllClose(rhs_matrix.eval(), [[0.019300, 0.023000, 0.026700], + [0.061600, 0.077000, 0.092400], + [0.160400, 0.220000, 0.279600], + [0.492800, 0.563200, 0.633600]]) + + def testWalsSolverLhsEntryWeights(self): + sparse_block = SparseBlock3x3() + with self.test_session(): + [lhs_tensor, + rhs_matrix] = gen_factorization_ops.wals_compute_partial_lhs_and_rhs( + self._column_factors, [], self._unobserved_weights, + [], sparse_block.indices, sparse_block.values, + [0.01, 0.03, 0.04, 0.03, 0.06, 0.12], + input_block_size=sparse_block.dense_shape[0], + input_is_transpose=False) self.assertAllClose(lhs_tensor.eval(), [[ [0.014800, 0.017000, 0.019200], [0.017000, 0.019600, 0.022200], diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py index 8f73274c2a0ebbdc41ce6a647a8a5650694c9a23..7ab70fbcfd7324961b61526a08daab7e393630e9 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py @@ -943,6 +943,7 @@ class WALSModel(object): row_weights_slice, new_sp_input.indices, new_sp_input.values, + [], num_rows, transpose_input, name="wals_compute_partial_lhs_rhs")) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py index b588f75efe9d0bbf8213a89978a627c0a0ccf554..05bcdac2caa77062f9a8a44a948d2897b439ea1f 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -95,7 +95,7 @@ def sequence_input_layer( Raises: ValueError: If any of the `feature_columns` is the wrong type. """ - feature_columns = fc._clean_feature_columns(feature_columns) + feature_columns = fc._normalize_feature_columns(feature_columns) for c in feature_columns: if not isinstance(c, fc._SequenceDenseColumn): raise ValueError( diff --git a/tensorflow/contrib/framework/python/ops/critical_section_test.py b/tensorflow/contrib/framework/python/ops/critical_section_test.py index df7d7e9dae80722569efccbc9cc0d1b75e90cf03..34fd5018af125335845540dedfdffc984ba02313 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_test.py +++ b/tensorflow/contrib/framework/python/ops/critical_section_test.py @@ -34,7 +34,7 @@ from tensorflow.python.platform import tf_logging as logging class CriticalSectionTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCreateCriticalSection(self): cs = critical_section_ops.CriticalSection(shared_name="cs") v = resource_variable_ops.ResourceVariable(0.0, name="v") @@ -53,7 +53,7 @@ class CriticalSectionTest(test.TestCase): self.assertAllClose([2.0 * i for i in range(num_concurrent)], sorted(r_value)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCriticalSectionWithControlFlow(self): for outer_cond in [False, True]: for inner_cond in [False, True]: @@ -109,7 +109,7 @@ class CriticalSectionTest(test.TestCase): with self.assertRaisesOpError("Error"): self.evaluate(r) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCreateCriticalSectionFnReturnsOp(self): cs = critical_section_ops.CriticalSection(shared_name="cs") v = resource_variable_ops.ResourceVariable(0.0, name="v") @@ -332,7 +332,7 @@ class CriticalSectionTest(test.TestCase): self.evaluate(v.initializer) self.assertEqual(10, self.evaluate(out)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInsideFunction(self): cs = critical_section_ops.CriticalSection() v = resource_variable_ops.ResourceVariable(1) diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py index a955e21b72e765f751318c7927f9644481fe7933..4d62ac65ff619f98a18387058fdc8a0eade0d8f8 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py @@ -21,8 +21,6 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op -from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl @@ -35,13 +33,6 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging -def NoMemoryOptimizationConfig(): - config = config_pb2.ConfigProto() - config.graph_options.rewrite_options.memory_optimization = ( - rewriter_config_pb2.RewriterConfig.OFF) - return config - - def GetShrunkInceptionShapes(shrink=10): """Iterator for smaller versions of convolution shapes in 2015 Inception. @@ -202,8 +193,7 @@ class FusedConv2DBiasActivationTest(test.TestCase): # This is to guarantee that there is always negative values after # bias add so that we can test whether relu works correctly. x3 = bias - # TODO(b/79323979): re-enable memory optimization after this bug is fixed. - with self.test_session(use_gpu=True, config=NoMemoryOptimizationConfig()): + with self.test_session(use_gpu=True): t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype) t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype) fused_t2 = t2 @@ -251,9 +241,7 @@ class FusedConv2DBiasActivationTest(test.TestCase): x3 = np.random.rand(*[filter_in_sizes[-1]]).astype(np.float32) def _SetupVal(data_format, use_gpu): - # TODO(b/79323979): re-enable memory optimization after this bug is fixed. - with self.test_session( - use_gpu=use_gpu, config=NoMemoryOptimizationConfig()): + with self.test_session(use_gpu=use_gpu): t1 = constant_op.constant(x1, shape=tensor_in_sizes) t2 = constant_op.constant(x2, shape=filter_in_sizes) t3 = constant_op.constant(x3, shape=[filter_in_sizes[-1]]) @@ -877,9 +865,7 @@ class FusedConvInt8Tests(test.TestCase): conv_input_scale, conv_input, kernel, padding_type, strides, side_input_scale, side_input, biases) - # TODO(b/79323979): re-enable memory optimization after this bug is fixed. - with self.test_session( - use_gpu=True, config=NoMemoryOptimizationConfig()) as sess: + with self.test_session(use_gpu=True) as sess: actual_y, expected_y = sess.run([actual, expected]) tf_logging.info("actual_y = ", actual_y) tf_logging.info("expected_y = ", expected_y) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index 5b5557bd8f12b4d42e508f185cb8561eaebea84e..d1441e1eb2aae0fb7d1771110f969bf727ebbb14 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -103,9 +103,20 @@ class GANHead(head._Head): # pylint: disable=protected-access name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. """ + + if not callable(generator_loss_fn): + raise TypeError('generator_loss_fn must be callable.') + if not callable(discriminator_loss_fn): + raise TypeError('discriminator_loss_fn must be callable.') + if not use_loss_summaries in [True, False, None]: + raise ValueError('use_loss_summaries must be True, False or None.') + if get_hooks_fn is not None and not callable(get_hooks_fn): + raise TypeError('get_hooks_fn must be callable.') + if name is not None and not isinstance(name, str): + raise TypeError('name must be string.') + if get_hooks_fn is None: get_hooks_fn = tfgan_train.get_sequential_train_hooks() - # TODO(joelshor): Validate inputs. if use_loss_summaries in [True, False]: generator_loss_fn = functools.partial( diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc index c2e32da133b32c8fe169302668031af8bace2c22..022e17d13963a14f81d76e683d13060d1f3f8a7e 100644 --- a/tensorflow/contrib/image/kernels/image_ops.cc +++ b/tensorflow/contrib/image/kernels/image_ops.cc @@ -35,6 +35,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; template struct FillProjectiveTransform; template struct FillProjectiveTransform; template struct FillProjectiveTransform; +template struct FillProjectiveTransform; template struct FillProjectiveTransform; template struct FillProjectiveTransform; @@ -99,6 +100,7 @@ class ImageProjectiveTransform : public OpKernel { TF_CALL_uint8(REGISTER); TF_CALL_int32(REGISTER); TF_CALL_int64(REGISTER); +TF_CALL_half(REGISTER); TF_CALL_float(REGISTER); TF_CALL_double(REGISTER); diff --git a/tensorflow/contrib/image/kernels/image_ops.h b/tensorflow/contrib/image/kernels/image_ops.h index ad501330617be89c87a0e94ab6e8773a6e1eecf6..209aa24548443bb10c13cd506b8c93c23cfff4a4 100644 --- a/tensorflow/contrib/image/kernels/image_ops.h +++ b/tensorflow/contrib/image/kernels/image_ops.h @@ -21,6 +21,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/types.h" @@ -58,6 +59,11 @@ class ProjectiveGenerator { ? transforms_.data() : &transforms_.data()[transforms_.dimension(1) * coords[0]]; float projection = transform[6] * output_x + transform[7] * output_y + 1.f; + if (projection == 0) { + // Return the fill value (0) for infinite coordinates, + // which are outside the input image + return T(0); + } const float input_x = (transform[0] * output_x + transform[1] * output_y + transform[2]) / projection; @@ -105,21 +111,21 @@ class ProjectiveGenerator { // f(x, y_floor) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_floor) // + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_floor) const float value_yfloor = - (x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_floor), - DenseIndex(x_floor), channel, - fill_value) + - (x - x_floor) * read_with_fill_value(batch, DenseIndex(y_floor), - DenseIndex(x_ceil), channel, - fill_value); + (x_ceil - x) * static_cast(read_with_fill_value( + batch, DenseIndex(y_floor), DenseIndex(x_floor), + channel, fill_value)) + + (x - x_floor) * static_cast(read_with_fill_value( + batch, DenseIndex(y_floor), DenseIndex(x_ceil), + channel, fill_value)); // f(x, y_ceil) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_ceil) // + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_ceil) const float value_yceil = - (x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_ceil), - DenseIndex(x_floor), channel, - fill_value) + - (x - x_floor) * read_with_fill_value(batch, DenseIndex(y_ceil), - DenseIndex(x_ceil), channel, - fill_value); + (x_ceil - x) * static_cast(read_with_fill_value( + batch, DenseIndex(y_ceil), DenseIndex(x_floor), + channel, fill_value)) + + (x - x_floor) * static_cast(read_with_fill_value( + batch, DenseIndex(y_ceil), DenseIndex(x_ceil), + channel, fill_value)); // f(x, y) = (y_ceil - y) / (y_ceil - y_floor) * f(x, y_floor) // + (y - y_floor) / (y_ceil - y_floor) * f(x, y_ceil) return T((y_ceil - y) * value_yfloor + (y - y_floor) * value_yceil); diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc index ebdcaea7abae2a967786831b62b331897aa3f6a3..e59f1bf8443732a4b84fe7461439e3d0ee7dd158 100644 --- a/tensorflow/contrib/image/ops/image_ops.cc +++ b/tensorflow/contrib/image/ops/image_ops.cc @@ -29,7 +29,7 @@ using shape_inference::ShapeHandle; REGISTER_OP("ImageProjectiveTransform") .Input("images: dtype") .Input("transforms: float32") - .Attr("dtype: {uint8, int32, int64, float32, float64}") + .Attr("dtype: {uint8, int32, int64, float16, float32, float64}") .Attr("interpolation: string") .Output("transformed_images: dtype") .SetShapeFn([](InferenceContext* c) { diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py index b50177ae5651fbc15f292e11031411c2074357ec..62a22dcf3411fb160b3c432bbdd67303697f7262 100644 --- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py @@ -30,7 +30,8 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest _DTYPES = set( - [dtypes.uint8, dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64]) + [dtypes.uint8, dtypes.int32, dtypes.int64, + dtypes.float16, dtypes.float32, dtypes.float64]) class ImageOpsTest(test_util.TensorFlowTestCase): @@ -127,6 +128,23 @@ class ImageOpsTest(test_util.TensorFlowTestCase): [0, 1, 0, 1], [0, 1, 1, 1]]) + def test_extreme_projective_transform(self): + for dtype in _DTYPES: + with self.test_session(): + image = constant_op.constant( + [[1, 0, 1, 0], + [0, 1, 0, 1], + [1, 0, 1, 0], + [0, 1, 0, 1]], dtype=dtype) + transformation = constant_op.constant([1, 0, 0, 0, 1, 0, -1, 0], + dtypes.float32) + image_transformed = image_ops.transform(image, transformation) + self.assertAllEqual(image_transformed.eval(), + [[1, 0, 0, 0], + [0, 0, 0, 0], + [1, 0, 0, 0], + [0, 0, 0, 0]]) + def test_bilinear(self): with self.test_session(): image = constant_op.constant( diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index cd984c80543886be1f682933e2e003bd3374e425..86b0ffe9a0f2236d5ac7d5f846e7b5d2615c9b09 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -33,7 +33,8 @@ _image_ops_so = loader.load_op_library( resource_loader.get_path_to_datafile("_image_ops.so")) _IMAGE_DTYPES = set( - [dtypes.uint8, dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64]) + [dtypes.uint8, dtypes.int32, dtypes.int64, + dtypes.float16, dtypes.float32, dtypes.float64]) ops.RegisterShape("ImageConnectedComponents")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn) diff --git a/tensorflow/contrib/keras/api/keras/layers/__init__.py b/tensorflow/contrib/keras/api/keras/layers/__init__.py index 938c881fcbe18623fa18c21c112375f9914f887b..3327a9f9a613bfb56e6a25af0fe1c0ca18609035 100644 --- a/tensorflow/contrib/keras/api/keras/layers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/layers/__init__.py @@ -20,10 +20,10 @@ from __future__ import print_function # Generic layers. # pylint: disable=g-bad-import-order -from tensorflow.python.keras.engine import Input -from tensorflow.python.keras.engine import InputLayer -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.keras.engine.input_layer import Input +from tensorflow.python.keras.engine.input_layer import InputLayer # Advanced activations. from tensorflow.python.keras.layers.advanced_activations import LeakyReLU diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index b6d63c9640611abdda65f1205f544ee505dae1f0..beeabd6b65631cad88efd10d5faee1917e162e41 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -2664,6 +2664,7 @@ def separable_convolution2d( normalizer_fn=None, normalizer_params=None, weights_initializer=initializers.xavier_initializer(), + pointwise_initializer=None, weights_regularizer=None, biases_initializer=init_ops.zeros_initializer(), biases_regularizer=None, @@ -2705,7 +2706,9 @@ def separable_convolution2d( `biases_regularizer` are ignored and `biases` are not created nor added. default set to None for no normalizer function normalizer_params: Normalization function parameters. - weights_initializer: An initializer for the weights. + weights_initializer: An initializer for the depthwise weights. + pointwise_initializer: An initializer for the pointwise weights. + default set to None, means use weights_initializer. weights_regularizer: Optional regularizer for the weights. biases_initializer: An initializer for the biases. If None skip biases. biases_regularizer: Optional regularizer for the biases. @@ -2737,6 +2740,9 @@ def separable_convolution2d( custom_getter=layer_variable_getter) as sc: inputs = ops.convert_to_tensor(inputs) + if pointwise_initializer is None: + pointwise_initializer = weights_initializer + df = ('channels_first' if data_format and data_format.startswith('NC') else 'channels_last') if num_outputs is not None: @@ -2752,7 +2758,7 @@ def separable_convolution2d( depth_multiplier=depth_multiplier, use_bias=not normalizer_fn and biases_initializer, depthwise_initializer=weights_initializer, - pointwise_initializer=weights_initializer, + pointwise_initializer=pointwise_initializer, bias_initializer=biases_initializer, depthwise_regularizer=weights_regularizer, pointwise_regularizer=weights_regularizer, diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 9c804d27854b8004d34c65691b48ca2b0d3bbf7c..8c17c65fcc0dbd58e2b3e9042a983e400cd6c2b9 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -184,6 +184,7 @@ cc_test( deps = [ ":framework", ":string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", "//tensorflow/contrib/lite/kernels:kernel_util", "//tensorflow/contrib/lite/kernels/internal:tensor_utils", "//tensorflow/contrib/lite/schema:schema_fbs", diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 612813caee880f3f7291ee9850f7d8f842d598a6..5543acc1f5dabaa8a54ec4d1f2027bc66a00f6db 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -214,6 +214,7 @@ def generated_test_models(): "global_batch_norm", "greater", "greater_equal", + "sum", "l2norm", "l2_pool", "less", @@ -232,11 +233,14 @@ def generated_test_models(): "pad", "padv2", # "prelu", + "pow", "relu", "relu1", "relu6", "reshape", "resize_bilinear", + "rsqrt", + "shape", "sigmoid", "sin", "slice", @@ -245,6 +249,7 @@ def generated_test_models(): "space_to_depth", "sparse_to_dense", "split", + "sqrt", "squeeze", "strided_slice", "strided_slice_1d_exhaustive", diff --git a/tensorflow/contrib/lite/build_ios_universal_lib.sh b/tensorflow/contrib/lite/build_ios_universal_lib.sh index 9f398f4a9f3dcafd7bd49fd5d95e9991b8b36b75..e9531aef19f04adf719156aa3e874dc5ce6e2b04 100755 --- a/tensorflow/contrib/lite/build_ios_universal_lib.sh +++ b/tensorflow/contrib/lite/build_ios_universal_lib.sh @@ -19,22 +19,23 @@ set -e SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" cd "$SCRIPT_DIR/../../.." -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8 \ -$SCRIPT_DIR/gen/lib/ios_x86_64/libtensorflow-lite.a -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8 \ -$SCRIPT_DIR/gen/lib/ios_i386/libtensorflow-lite.a -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7 -j 8 \ -$SCRIPT_DIR/gen/lib/ios_armv7/libtensorflow-lite.a -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7s -j 8 \ -$SCRIPT_DIR/gen/lib/ios_armv7s/libtensorflow-lite.a -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=arm64 -j 8 \ -$SCRIPT_DIR/gen/lib/ios_arm64/libtensorflow-lite.a +# Build library for supported architectures and packs them in a fat binary. +make_library() { + for arch in x86_64 i386 armv7 armv7s arm64 + do + make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=${arch} \ + -j 8 \ + $SCRIPT_DIR/gen/lib/ios_${arch}/${1} + done + lipo \ + tensorflow/contrib/lite/gen/lib/ios_x86_64/${1} \ + tensorflow/contrib/lite/gen/lib/ios_i386/${1} \ + tensorflow/contrib/lite/gen/lib/ios_armv7/${1} \ + tensorflow/contrib/lite/gen/lib/ios_armv7s/${1} \ + tensorflow/contrib/lite/gen/lib/ios_arm64/${1} \ + -create \ + -output tensorflow/contrib/lite/gen/lib/${1} +} -lipo \ -tensorflow/contrib/lite/gen/lib/ios_x86_64/libtensorflow-lite.a \ -tensorflow/contrib/lite/gen/lib/ios_i386/libtensorflow-lite.a \ -tensorflow/contrib/lite/gen/lib/ios_armv7/libtensorflow-lite.a \ -tensorflow/contrib/lite/gen/lib/ios_armv7s/libtensorflow-lite.a \ -tensorflow/contrib/lite/gen/lib/ios_arm64/libtensorflow-lite.a \ --create \ --output tensorflow/contrib/lite/gen/lib/libtensorflow-lite.a +make_library libtensorflow-lite.a +make_library benchmark-lib.a diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index c1cc4476fbd45fa6b3f5b3a1ed2cba39cc2ad54b..cda889bf502a535eac4249bbae645359cdb2135d 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -92,8 +92,17 @@ typedef struct { TfLiteFusedActivation activation; } TfLiteSequenceRNNParams; +typedef enum { + kTfLiteFullyConnectedWeightsFormatDefault = 0, + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1, +} TfLiteFullyConnectedWeightsFormat; + typedef struct { + // Parameters for FullyConnected version 1 or above. TfLiteFusedActivation activation; + + // Parameters for FullyConnected version 2 or above. + TfLiteFullyConnectedWeightsFormat weights_format; } TfLiteFullyConnectedParams; typedef enum { @@ -215,7 +224,7 @@ typedef struct { typedef struct { bool keep_dims; -} TfLiteMeanParams; +} TfLiteReducerParams; typedef struct { int num_splits; @@ -250,6 +259,10 @@ typedef struct { bool validate_indices; } TfLiteSparseToDenseParams; +typedef struct { + TfLiteType out_type; +} TfLiteShapeParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index aef9a92883f18dabfc36058507d739856c3c2af7..a44e9182302d19acd1e1c183ed388531eec11d93 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -99,6 +99,11 @@ typedef enum { kTfLiteBuiltinEqual = 71, kTfLiteBuiltinNotEqual = 72, kTfLiteBuiltinLog = 73, + kTfLiteBuiltinSum = 74, + kTfLiteBuiltinSqrt = 75, + kTfLiteBuiltinRsqrt = 76, + kTfLiteBuiltinShape = 77, + kTfLiteBuiltinPow = 78, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index 15a37de9dc665ff147b7094a61a5afab701932ce..1265c4cba9064cc5aba9af81415f857ad00f6d99 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -139,6 +139,7 @@ typedef enum { kTfLiteString = 5, kTfLiteBool = 6, kTfLiteInt16 = 7, + kTfLiteComplex64 = 8, } TfLiteType; // Parameters for asymmetric quantization. Quantized values can be converted @@ -159,6 +160,7 @@ typedef union { uint8_t* uint8; bool* b; int16_t* i16; + _Complex float* c64; } TfLitePtrUnion; // Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped @@ -374,6 +376,14 @@ typedef struct _TfLiteRegistration { // Returns kTfLiteOk on success. TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node); + // profiling_string is called during summarization of profiling information + // in order to group executions together. Providing a value here will cause a + // given op to appear multiple times is the profiling report. This is + // particularly useful for custom ops that can perform significantly + // different calculations depending on their `user-data`. + const char* (*profiling_string)(const TfLiteContext* context, + const TfLiteNode* node); + // Builtin codes. If this kernel refers to a builtin this is the code // of the builtin. This is so we can do marshaling to other frameworks like // NN API. diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc index 0731d14419d2dec2ea5efa48ef5d4b7728af635f..fd798c209e5112235cf6e351e231d4096006a8b0 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc @@ -26,6 +26,10 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" +#ifdef __ANDROID__ +#include +#endif + namespace tflite { namespace { @@ -37,6 +41,32 @@ namespace { return kTfLiteError; \ } +namespace { +int32_t GetAndroidSdkVersion() { +#ifdef __ANDROID__ + const char* sdkProp = "ro.build.version.sdk"; + char sdkVersion[PROP_VALUE_MAX]; + int length = __system_property_get(sdkProp, sdkVersion); + if (length != 0) { + for (int i = 0; i < length; ++i) { + int digit = sdkVersion[i] - '0'; + if (digit < 0 || digit > 9) { + // Non-numeric SDK version, assume it's higher then expected; + return std::numeric_limits::max(); + } + } + return atoi(sdkVersion); + } +#endif // __ANDROID__ + return 0; +} + +constexpr int32_t kMinSdkVersionForNNAPI = 27; +constexpr int32_t kMinSdkVersionForNNAPI11 = 28; +static const int32_t kAndroidSdkVersion = GetAndroidSdkVersion(); + +} // namespace + // RAII NN API Model Destructor for use with std::unique_ptr struct NNFreeModel { void operator()(ANeuralNetworksModel* model) { @@ -71,7 +101,7 @@ class OperandMapping { // Add a new mapping from `tflite_index` and return the NN API tensor index. int add_new_ann_tensor_index(int tflite_index) { if (tflite_index >= lite_tensor_to_ann_tensor_.size()) { - lite_tensor_to_ann_tensor_.resize(tflite_index + 1); + lite_tensor_to_ann_tensor_.resize(tflite_index + 1, -1); } int new_tensor_index = next_ann_tensor_index_++; lite_tensor_to_ann_tensor_[tflite_index] = new_tensor_index; @@ -98,14 +128,28 @@ class NNAPIOpBuilder { operand_mapping_(tensor_mapping), nn_model_(nn_model) {} - TfLiteStatus AddScalarInt32Operand(int value) { - ANeuralNetworksOperandType operand_type{.type = ANEURALNETWORKS_INT32}; - CHECK_NN(context_, - ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); - int ann_operand = operand_mapping_->add_new_non_tensor_operand(); - CHECK_NN(context_, ANeuralNetworksModel_setOperandValue( - nn_model_, ann_operand, &value, sizeof(int32_t))); - augmented_inputs_.push_back(ann_operand); + TfLiteStatus AddScalarInt32Operand(int32_t value) { + return AddScalarOperand(value, ANEURALNETWORKS_INT32); + } + + TfLiteStatus AddScalarFloat32Operand(float value) { + return AddScalarOperand(value, ANEURALNETWORKS_FLOAT32); + } + + TfLiteStatus AddVectorInt32Operand(const int32_t* values, + uint32_t num_values) { + return AddVectorOperand(values, num_values, + ANEURALNETWORKS_TENSOR_INT32); + } + + TfLiteStatus AddPoolingParams(void* data) { + auto builtin = reinterpret_cast(data); + AddScalarInt32Operand(builtin->padding); + AddScalarInt32Operand(builtin->stride_width); + AddScalarInt32Operand(builtin->stride_height); + AddScalarInt32Operand(builtin->filter_width); + AddScalarInt32Operand(builtin->filter_height); + AddScalarInt32Operand(builtin->activation); return kTfLiteOk; } @@ -149,7 +193,6 @@ class NNAPIOpBuilder { return kTfLiteOk; case kTfLiteFloat32: nn_type = ANEURALNETWORKS_TENSOR_FLOAT32; - scale = 0.f; break; case kTfLiteUInt8: nn_type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM; @@ -158,8 +201,8 @@ class NNAPIOpBuilder { break; case kTfLiteInt32: nn_type = ANEURALNETWORKS_TENSOR_INT32; - scale = 0.f; - zeroPoint = 0; + scale = tensor->params.scale; + zeroPoint = tensor->params.zero_point; break; default: context_->ReportError(context_, "Logic error in NN API Delegate.\n"); @@ -192,12 +235,39 @@ class NNAPIOpBuilder { augmented_inputs_.data(), static_cast(augmented_outputs_.size()), augmented_outputs_.data())); - augmented_outputs_.clear(); + augmented_inputs_.clear(); augmented_outputs_.clear(); return kTfLiteOk; } private: + template + TfLiteStatus AddScalarOperand(T value, int32_t nn_type) { + ANeuralNetworksOperandType operand_type{.type = nn_type}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + int ann_operand = operand_mapping_->add_new_non_tensor_operand(); + CHECK_NN(context_, ANeuralNetworksModel_setOperandValue( + nn_model_, ann_operand, &value, sizeof(T))); + augmented_inputs_.push_back(ann_operand); + return kTfLiteOk; + } + + template + TfLiteStatus AddVectorOperand(const T* values, uint32_t num_values, + int32_t nn_type) { + ANeuralNetworksOperandType operand_type{ + .type = nn_type, .dimensionCount = 1, .dimensions = &num_values}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + int ann_operand = operand_mapping_->add_new_non_tensor_operand(); + CHECK_NN(context_, + ANeuralNetworksModel_setOperandValue( + nn_model_, ann_operand, values, sizeof(T) * num_values)); + augmented_inputs_.push_back(ann_operand); + return kTfLiteOk; + } + // TfLiteContext for error handling. Must be named context for macros to // work. TfLiteContext* context_; @@ -227,29 +297,161 @@ class NNAPIDelegateKernel { // Return a function that knows how to translate a node into its operands // when called. You can use this function to see if a node is supported // (i.e. that MappingFn is not nullptr). - MappingFn Map(TfLiteContext* context, int builtin_code, TfLiteNode* node) { + MappingFn Map(TfLiteContext* context, int builtin_code, int version, + TfLiteNode* node) { switch (builtin_code) { case kTfLiteBuiltinAdd: - return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { - auto builtin = reinterpret_cast(node->builtin_data); - builder->AddScalarInt32Operand(builtin->activation); - return ANEURALNETWORKS_ADD; - }; + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_ADD; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinMul: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_MUL; + }; + } else { + return nullptr; + } break; case kTfLiteBuiltinAveragePool2d: - return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + builder->AddPoolingParams(node->builtin_data); + return ANEURALNETWORKS_AVERAGE_POOL_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinMaxPool2d: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + builder->AddPoolingParams(node->builtin_data); + return ANEURALNETWORKS_MAX_POOL_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinL2Pool2d: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + builder->AddPoolingParams(node->builtin_data); + return ANEURALNETWORKS_L2_POOL_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinConv2d: + if (version == 1) { auto builtin = - reinterpret_cast(node->builtin_data); - builder->AddScalarInt32Operand(builtin->padding); - builder->AddScalarInt32Operand(builtin->stride_width); - builder->AddScalarInt32Operand(builtin->stride_height); - builder->AddScalarInt32Operand(builtin->filter_width); - builder->AddScalarInt32Operand(builtin->filter_height); - builder->AddScalarInt32Operand(builtin->activation); - return ANEURALNETWORKS_AVERAGE_POOL_2D; - }; + reinterpret_cast(node->builtin_data); + if (builtin->dilation_width_factor != 1 || + builtin->dilation_height_factor != 1 || node->inputs->size != 3) { + // NNAPI does not support dilated Conv2D. + return nullptr; + } + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->padding); + builder->AddScalarInt32Operand(builtin->stride_width); + builder->AddScalarInt32Operand(builtin->stride_height); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_CONV_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinDepthwiseConv2d: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + node->builtin_data); + builder->AddScalarInt32Operand(builtin->padding); + builder->AddScalarInt32Operand(builtin->stride_width); + builder->AddScalarInt32Operand(builtin->stride_height); + builder->AddScalarInt32Operand(builtin->depth_multiplier); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_DEPTHWISE_CONV_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinFullyConnected: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_FULLY_CONNECTED; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinSoftmax: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarFloat32Operand(builtin->beta); + return ANEURALNETWORKS_SOFTMAX; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinReshape: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_RESHAPE; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinSqueeze: + // Squeeze requires NNAPI1.1. + if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + // Note that we add the squeeze dimensions even if the dimensions + // were unspecified (empty), as NNAPI requires the operand. + builder->AddVectorInt32Operand( + builtin->squeeze_dims, + static_cast(builtin->num_squeeze_dims)); + return ANEURALNETWORKS_SQUEEZE; + }; + } else { + return nullptr; + } break; default: return nullptr; @@ -292,10 +494,14 @@ class NNAPIDelegateKernel { int relative_input_index = 0; for (auto absolute_input_index : TfLiteIntArrayView(node->inputs)) { TfLiteTensor* tensor = &context->tensors[absolute_input_index]; - CHECK_NN(context, ANeuralNetworksExecution_setInput( - execution, relative_input_index, nullptr, - tensor->data.raw, tensor->bytes)); - relative_input_index++; + // TODO(miaowang): make sure the delegation works with dequantized weights + // as intermediate tensors. + if (tensor->allocation_type != kTfLiteMmapRo) { + CHECK_NN(context, ANeuralNetworksExecution_setInput( + execution, relative_input_index, nullptr, + tensor->data.raw, tensor->bytes)); + relative_input_index++; + } } // Set the output tensor buffers. @@ -345,8 +551,8 @@ class NNAPIDelegateKernel { TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index)); } // Get op type and operands - int nn_op_type = - Map(context, reg->builtin_code, node)(context, &builder, node); + int nn_op_type = Map(context, reg->builtin_code, reg->version, node)( + context, &builder, node); // Map outputs to NN API tensor indices. for (auto output_index : TfLiteIntArrayView(node->outputs)) { TF_LITE_ENSURE_STATUS(builder.AddTensorOutput(output_index)); @@ -368,8 +574,12 @@ class NNAPIDelegateKernel { std::vector outputs; outputs.reserve(output_tensors->size); // Make the TensorFlow lite inputs and outputs to ann_indices. - for (int i : TfLiteIntArrayView(input_tensors)) - inputs.push_back(operand_mapping_.lite_index_to_ann(i)); + for (int i : TfLiteIntArrayView(input_tensors)) { + // Constant tensors are not NNAPI inputs. + if (context->tensors[i].allocation_type != kTfLiteMmapRo) { + inputs.push_back(operand_mapping_.lite_index_to_ann(i)); + } + } for (int i : TfLiteIntArrayView(output_tensors)) outputs.push_back(operand_mapping_.lite_index_to_ann(i)); // Tell ANN to declare inputs/outputs @@ -392,7 +602,9 @@ TfLiteDelegate* NnApiDelegate() { .Prepare = [](TfLiteContext* context, TfLiteDelegate* delegate) -> TfLiteStatus { // Do not check nodes_ if NN API is unavailable. - if (!NNAPIExists()) return kTfLiteOk; + if (kAndroidSdkVersion < kMinSdkVersionForNNAPI || !NNAPIExists()) { + return kTfLiteOk; + } std::vector supported_nodes(1); // We don't care about all nodes_, we only care about ones in the @@ -400,6 +612,7 @@ TfLiteDelegate* NnApiDelegate() { TfLiteIntArray* plan; TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); int total_supported_nodes = 0; + // Check for every node if it is supported // TODO(b/80625235): Fix this to do more careful checking of versioning. for (int node_index : TfLiteIntArrayView(plan)) { @@ -408,7 +621,8 @@ TfLiteDelegate* NnApiDelegate() { TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( context, node_index, &node, ®istration)); NNAPIDelegateKernel dummy_kernel; - if (dummy_kernel.Map(context, registration->builtin_code, node)) { + if (dummy_kernel.Map(context, registration->builtin_code, + registration->version, node)) { supported_nodes.push_back(node_index); } total_supported_nodes += 1; diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc index ff2e721423f07889f36746a2889afcc3369f28fc..aad10c9ce730a2e90481a123a1e3e323cfb2bd42 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc @@ -21,8 +21,12 @@ limitations under the License. namespace tflite { namespace { +using ::testing::ElementsAre; using ::testing::ElementsAreArray; +// TODO(b/110368244): figure out how to share the existing tests in kernels/ but +// with the delegation on. Also, add more unit tests to improve code coverage. + class FloatAddOpModel : public SingleOpModel { public: FloatAddOpModel(const TensorData& input1, const TensorData& input2, @@ -72,6 +76,596 @@ TEST(NNAPIDelegate, AddWithRelu) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({0.0, 0.4, 1.0, 1.3})); } +class FloatMulOpModel : public SingleOpModel { + public: + FloatMulOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, + ActivationFunctionType activation_type) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions, + CreateMulOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input1_; + int input2_; + int output_; +}; + +TEST(NNAPIDelegate, MulWithNoActivation) { + FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4}))); +} + +class FloatPoolingOpModel : public SingleOpModel { + public: + FloatPoolingOpModel(BuiltinOperator type, const TensorData& input, + int filter_width, int filter_height, + const TensorData& output) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + input_ = AddInput(input); + output_ = AddOutput(output); + + SetBuiltinOp( + type, BuiltinOptions_Pool2DOptions, + CreatePool2DOptions(builder_, Padding_VALID, 2, 2, filter_width, + filter_height, ActivationFunctionType_NONE) + .Union()); + + BuildInterpreter({GetShape(input_)}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int output_; +}; + +TEST(NNAPIDelegate, AveragePoolWithNoActivation) { + FloatPoolingOpModel m(BuiltinOperator_AVERAGE_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2.75, 5.75})); +} + +TEST(NNAPIDelegate, MaxPoolWithNoActivation) { + FloatPoolingOpModel m(BuiltinOperator_MAX_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 10})); +} + +TEST(NNAPIDelegate, L2PoolWithNoActivation) { + FloatPoolingOpModel m(BuiltinOperator_L2_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.5, 6.5})); +} + +class BaseConvolutionOpModel : public SingleOpModel { + public: + BaseConvolutionOpModel( + const TensorData& input, const TensorData& filter, + const TensorData& output, int stride_width = 2, int stride_height = 2, + enum Padding padding = Padding_VALID, + enum ActivationFunctionType activation = ActivationFunctionType_NONE, + int dilation_width_factor = 1, int dilation_height_factor = 1) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + input_ = AddInput(input); + filter_ = AddInput(filter); + + int bias_size = GetShape(filter_)[0]; + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {bias_size}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. + auto bias_scale = GetScale(input_) * GetScale(filter_); + TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + if (input.type != TensorType_FLOAT32) { + // The following is required by quantized inference. It is the unittest's + // responsibility to make sure the output scale falls into the correct + // range. + CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_)); + } + + SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions, + CreateConv2DOptions( + builder_, padding, stride_width, stride_height, activation, + dilation_width_factor, dilation_height_factor) + .Union()); + + BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); + } + + protected: + int input_; + int filter_; + int bias_; + int output_; +}; + +class ConvolutionOpModel : public BaseConvolutionOpModel { + public: + using BaseConvolutionOpModel::BaseConvolutionOpModel; + + void SetFilter(std::initializer_list f) { PopulateTensor(filter_, f); } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class QuantizedConvolutionOpModel : public BaseConvolutionOpModel { + public: + using BaseConvolutionOpModel::BaseConvolutionOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + void SetFilter(std::initializer_list data) { + QuantizeAndPopulate(filter_, data); + } + + void SetBias(std::initializer_list data) { + QuantizeAndPopulate(bias_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +// In this tests we set the input and output scales so that the results +// match exactly the 'non-quantized' version. +TEST(NNAPIDelegate, SimpleTestQuantized) { + QuantizedConvolutionOpModel m({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64}, + {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64}, + {TensorType_UINT8, {}, -127, 128}); + m.SetInput({ + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }); + m.SetFilter({ + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + }, + 1e-5))); + // For good measure, let's also verify the quantized values: + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 145, 129, 132, // + 145, 129, 132, // + 144, 131, 130, // + 164, 131, 130, // + })); +} + +TEST(NNAPIDelegate, Conv2DWithNoActivation) { + ConvolutionOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}}, + {TensorType_FLOAT32, {3, 2, 2, 1}}, + {TensorType_FLOAT32, {}}); + + m.SetInput({ + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }); + m.SetFilter({ + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + })); +} + +class DepthwiseConvolutionOpModel : public SingleOpModel { + public: + DepthwiseConvolutionOpModel(const TensorData& input, const TensorData& filter, + const TensorData& output) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + input_ = AddInput(input); + filter_ = AddInput(filter); + + int bias_size = GetShape(filter_)[3]; + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {bias_size}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. + auto bias_scale = GetScale(input_) * GetScale(filter_); + TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + + int input_depth = GetShape(input_)[3]; + int output_depth = GetShape(filter_)[3]; + int depth_mul = output_depth / input_depth; + + SetBuiltinOp( + BuiltinOperator_DEPTHWISE_CONV_2D, + BuiltinOptions_DepthwiseConv2DOptions, + CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul, + ActivationFunctionType_NONE) + .Union()); + + BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); + } + + void SetFilter(std::initializer_list f) { PopulateTensor(filter_, f); } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int filter_; + int bias_; + int output_; +}; + +TEST(NNAPIDelegate, DepthwiseConv2DWithNoActivation) { + DepthwiseConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 2, 2}}, + {TensorType_FLOAT32, {1, 2, 2, 4}}, + {TensorType_FLOAT32, {}}); + + m.SetInput({ + 1, 2, 7, 8, // column 1 + 3, 4, 9, 10, // column 2 + 5, 6, 11, 12, // column 3 + }); + m.SetFilter({ + 1, 2, 3, 4, // + -9, 10, -11, 12, // + 5, 6, 7, 8, // + 13, -14, 15, -16, // + }); + m.SetBias({1, 2, 3, 4}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 71, -34, 99, -20, // + 91, -26, 127, -4, // + })); +} + +class FloatFullyConnectedOpModel : public SingleOpModel { + public: + FloatFullyConnectedOpModel(int units, int batches, const TensorData& input, + const TensorData& output = {TensorType_FLOAT32}) + : batches_(batches), units_(units) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + int total_input_size = 1; + for (int i = 0; i < input.shape.size(); ++i) { + total_input_size *= input.shape[i]; + } + input_size_ = total_input_size / batches_; + + input_ = AddInput(input); + weights_ = + AddInput({input.type, {units_, input_size_}, input.min, input.max}); + + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {units_}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. + auto bias_scale = GetScale(input_) * GetScale(weights_); + TensorData bias{TensorType_INT32, {units_}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + + SetBuiltinOp( + BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions, + CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU) + .Union()); + BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)}); + } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetWeights(std::initializer_list f) { + PopulateTensor(weights_, f); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int weights_; + int bias_; + int output_; + + int batches_; + int units_; + int input_size_; +}; + +TEST(NNAPIDelegate, FullyConnectedSimpleTest) { + FloatFullyConnectedOpModel m(/*units=*/3, /*batches=*/2, + /*input=*/{TensorType_FLOAT32, {2, 10}}); + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60)); +} + +class SoftmaxOpModel : public SingleOpModel { + public: + SoftmaxOpModel(int batches, int size, float beta) + : batches_(batches), input_size_(size), beta_(beta) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions, + CreateSoftmaxOptions(builder_, beta_).Union()); + BuildInterpreter({{batches_, input_size_}}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; + + int batches_; + int input_size_; + float beta_; +}; + +TEST(NNAPIDelegate, SoftmaxSimpleTest) { + SoftmaxOpModel m(/*batches=*/2, /*size=*/5, /*beta=*/1.0); + m.SetInput({ + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 0 + }); + + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647, + 0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231}, + 1e-6))); +} + +class ReshapeOpModel : public SingleOpModel { + public: + ReshapeOpModel(std::initializer_list input_shape, + std::initializer_list new_shape) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + input_ = AddInput(TensorType_FLOAT32); + new_shape_ = AddInput(TensorType_INT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions, + CreateReshapeOptions(builder_, builder_.CreateVector(new_shape)) + .Union()); + BuildInterpreter({input_shape, {static_cast(new_shape.size())}}); + PopulateTensor(new_shape_, new_shape); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int new_shape_; + int output_; +}; + +TEST(NNAPIDelegate, ReshapeSimpleTest) { + ReshapeOpModel m({1, 2, 4, 1}, {2, 2, 2}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); +} + +class SqueezeOpModel : public SingleOpModel { + public: + SqueezeOpModel(const TensorData& input, const TensorData& output, + std::initializer_list axis) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp( + BuiltinOperator_SQUEEZE, BuiltinOptions_SqueezeOptions, + CreateSqueezeOptions(builder_, builder_.CreateVector(axis)) + .Union()); + BuildInterpreter({GetShape(input_)}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int new_shape_; + int output_; +}; + +TEST(NNAPIDelegate, SqueezeSimpleTest) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + SqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}}, {TensorType_FLOAT32, {24}}, + {}); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({24})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0})); +} + +TEST(NNAPIDelegate, SqueezeWithAxisTest) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + SqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}}, {TensorType_FLOAT32, {24}}, + {2}); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 24})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/examples/android/BUILD b/tensorflow/contrib/lite/examples/android/BUILD index 57000072561303e8457f61b1ebe95d382fc01f10..dd2cd173246719976d7cd6e52d65f63125b5b2db 100644 --- a/tensorflow/contrib/lite/examples/android/BUILD +++ b/tensorflow/contrib/lite/examples/android/BUILD @@ -1,6 +1,8 @@ # Description: # TensorFlow camera demo app for Android. +load("@build_bazel_rules_android//android:rules.bzl", "android_binary") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 @@ -24,28 +26,28 @@ cc_library( android_binary( name = "tflite_demo", srcs = glob([ - "src/**/*.java", + "app/src/main/java/**/*.java", ]), # Package assets from assets dir as well as all model targets. # Remove undesired models (and corresponding Activities in source) # to reduce APK size. assets = [ - "//tensorflow/contrib/lite/examples/android/assets:labels_mobilenet_quant_v1_224.txt", + "//tensorflow/contrib/lite/examples/android/app/src/main/assets:labels_mobilenet_quant_v1_224.txt", "@tflite_mobilenet//:mobilenet_quant_v1_224.tflite", "@tflite_conv_actions_frozen//:conv_actions_frozen.tflite", - "//tensorflow/contrib/lite/examples/android/assets:conv_actions_labels.txt", + "//tensorflow/contrib/lite/examples/android/app/src/main/assets:conv_actions_labels.txt", "@tflite_mobilenet_ssd//:mobilenet_ssd.tflite", - "//tensorflow/contrib/lite/examples/android/assets:box_priors.txt", - "//tensorflow/contrib/lite/examples/android/assets:coco_labels_list.txt", + "//tensorflow/contrib/lite/examples/android/app/src/main/assets:box_priors.txt", + "//tensorflow/contrib/lite/examples/android/app/src/main/assets:coco_labels_list.txt", ], assets_dir = "", custom_package = "org.tensorflow.lite.demo", inline_constants = 1, - manifest = "AndroidManifest.xml", + manifest = "app/src/main/AndroidManifest.xml", nocompress_extensions = [ ".tflite", ], - resource_files = glob(["res/**"]), + resource_files = glob(["app/src/main/res/**"]), tags = [ "manual", "notap", @@ -55,31 +57,3 @@ android_binary( "//tensorflow/contrib/lite/java:tensorflowlite", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - "bin/**", - "gen/**", - "gradleBuild/**", - "libs/**", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - -filegroup( - name = "java_files", - srcs = glob(["src/**/*.java"]), -) - -filegroup( - name = "resource_files", - srcs = glob(["res/**"]), -) - -exports_files(["AndroidManifest.xml"]) diff --git a/tensorflow/contrib/lite/examples/android/android.iml b/tensorflow/contrib/lite/examples/android/android.iml new file mode 100644 index 0000000000000000000000000000000000000000..f0a5ac2bf4cdfb7c98f5704310fbf2f16e9065a2 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/android.iml @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tensorflow/contrib/lite/examples/android/app/build.gradle b/tensorflow/contrib/lite/examples/android/app/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..8e0a98ed63f99b7477cdb2f851a19cd31b45f314 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/app/build.gradle @@ -0,0 +1,60 @@ +apply plugin: 'com.android.application' + +android { + compileSdkVersion 26 + buildToolsVersion '26.0.2' + defaultConfig { + applicationId "org.tensorflow.lite.demo" + minSdkVersion 15 + targetSdkVersion 26 + versionCode 1 + versionName "1.0" + testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" + + // Remove this block. + jackOptions { + enabled true + } + } + lintOptions { + abortOnError false + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' + } + } + aaptOptions { + noCompress "tflite" + } + + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +repositories { + maven { + url 'https://google.bintray.com/tensorflow' + } +} + +// import DownloadModels task +project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets' +project.ext.TMP_DIR = project.buildDir.toString() + '/downloads' + +// Download default models; if you wish to use your own models then +// place them in the "assets" directory and comment out this line. +apply from: "download-models.gradle" + +dependencies { + compile fileTree(dir: 'libs', include: ['*.jar']) + androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { + exclude group: 'com.android.support', module: 'support-annotations' + }) + compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly' + + testCompile 'junit:junit:4.12' +} diff --git a/tensorflow/contrib/lite/examples/android/app/download-models.gradle b/tensorflow/contrib/lite/examples/android/app/download-models.gradle new file mode 100644 index 0000000000000000000000000000000000000000..8e65dc076f2a8daaddf01ceab6796b8ed1127af3 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/app/download-models.gradle @@ -0,0 +1,73 @@ +/* + * download-models.gradle + * Downloads model files from ${MODEL_URL} into application's asset folder + * Input: + * project.ext.TMP_DIR: absolute path to hold downloaded zip files + * project.ext.ASSET_DIR: absolute path to save unzipped model files + * Output: + * 3 model files will be downloaded into given folder of ext.ASSET_DIR + */ +// hard coded model files +// LINT.IfChange + +def models = ['conv_actions_tflite.zip', + 'mobilenet_ssd_tflite_v1.zip', + 'mobilenet_v1_224_android_quant_2017_11_08.zip'] +// LINT.ThenChange(//tensorflow/examples/android/BUILD) + +// Root URL for model archives +def MODEL_URL = 'https://storage.googleapis.com/download.tensorflow.org/models/tflite' + +buildscript { + repositories { + jcenter() + } + dependencies { + classpath 'de.undercouch:gradle-download-task:3.2.0' + } +} + +import de.undercouch.gradle.tasks.download.Download +task downloadFile(type: Download){ + for (f in models) { + def modelUrl = MODEL_URL + "/" + f + println "Downloading ${f} from ${modelUrl}" + src modelUrl + } + + dest new File(project.ext.TMP_DIR) + overwrite true +} + +task extractModels(type: Copy) { + for (f in models) { + def localFile = f.split("/")[-1] + from zipTree(project.ext.TMP_DIR + '/' + localFile) + } + + into file(project.ext.ASSET_DIR) + fileMode 0644 + exclude '**/LICENSE' + + def needDownload = false + for (f in models) { + def localFile = f.split("/")[-1] + if (!(new File(project.ext.TMP_DIR + '/' + localFile)).exists()) { + needDownload = true + } + } + + if (needDownload) { + dependsOn downloadFile + } +} + +tasks.whenTaskAdded { task -> + if (task.name == 'assembleDebug') { + task.dependsOn 'extractModels' + } + if (task.name == 'assembleRelease') { + task.dependsOn 'extractModels' + } +} + diff --git a/tensorflow/contrib/lite/examples/android/AndroidManifest.xml b/tensorflow/contrib/lite/examples/android/app/src/main/AndroidManifest.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/AndroidManifest.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/AndroidManifest.xml diff --git a/tensorflow/contrib/lite/examples/android/assets/BUILD b/tensorflow/contrib/lite/examples/android/app/src/main/assets/BUILD similarity index 100% rename from tensorflow/contrib/lite/examples/android/assets/BUILD rename to tensorflow/contrib/lite/examples/android/app/src/main/assets/BUILD diff --git a/tensorflow/contrib/lite/examples/android/assets/box_priors.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/box_priors.txt similarity index 100% rename from tensorflow/contrib/lite/examples/android/assets/box_priors.txt rename to tensorflow/contrib/lite/examples/android/app/src/main/assets/box_priors.txt diff --git a/tensorflow/contrib/lite/examples/android/assets/coco_labels_list.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/coco_labels_list.txt similarity index 100% rename from tensorflow/contrib/lite/examples/android/assets/coco_labels_list.txt rename to tensorflow/contrib/lite/examples/android/app/src/main/assets/coco_labels_list.txt diff --git a/tensorflow/contrib/lite/examples/android/assets/conv_actions_labels.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/conv_actions_labels.txt similarity index 100% rename from tensorflow/contrib/lite/examples/android/assets/conv_actions_labels.txt rename to tensorflow/contrib/lite/examples/android/app/src/main/assets/conv_actions_labels.txt diff --git a/tensorflow/contrib/lite/examples/android/assets/labels_mobilenet_quant_v1_224.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/labels_mobilenet_quant_v1_224.txt similarity index 100% rename from tensorflow/contrib/lite/examples/android/assets/labels_mobilenet_quant_v1_224.txt rename to tensorflow/contrib/lite/examples/android/app/src/main/assets/labels_mobilenet_quant_v1_224.txt diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/AutoFitTextureView.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/AutoFitTextureView.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/AutoFitTextureView.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/AutoFitTextureView.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/CameraActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraActivity.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/CameraActivity.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraActivity.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraConnectionFragment.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraConnectionFragment.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/Classifier.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/Classifier.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/Classifier.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/Classifier.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/ClassifierActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ClassifierActivity.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/ClassifierActivity.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ClassifierActivity.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/DetectorActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/DetectorActivity.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/LegacyCameraConnectionFragment.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/LegacyCameraConnectionFragment.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/OverlayView.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/OverlayView.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/OverlayView.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/OverlayView.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/RecognitionScoreView.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognitionScoreView.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/RecognitionScoreView.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognitionScoreView.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/RecognizeCommands.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognizeCommands.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/RecognizeCommands.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognizeCommands.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/ResultsView.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ResultsView.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/ResultsView.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ResultsView.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/SpeechActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/SpeechActivity.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/SpeechActivity.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/SpeechActivity.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteImageClassifier.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteImageClassifier.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteImageClassifier.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteImageClassifier.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java similarity index 91% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java index bfb4a0a04bc90566736864bf62340d1032961858..580206943b303770419d1877012855a4e6bc3c2f 100644 --- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java +++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java @@ -25,6 +25,8 @@ import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.util.ArrayList; @@ -54,6 +56,14 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { private static final float H_SCALE = 5.0f; private static final float W_SCALE = 5.0f; + // Float model + private static final float IMAGE_MEAN = 128.0f; + private static final float IMAGE_STD = 128.0f; + + //Number of threads in the java app + private static final int NUM_THREADS = 4; + + // Config values. private int inputSize; @@ -65,7 +75,7 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { private float[][][] outputLocations; private float[][][] outputClasses; - float[][][][] img; + private ByteBuffer imgData = null; private Interpreter tfLite; @@ -176,9 +186,12 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { } // Pre-allocate buffers. - d.img = new float[1][inputSize][inputSize][3]; - + int numBytesPerChannel = 4; // Floating point + d.imgData = ByteBuffer.allocateDirect(1 * d.inputSize * d.inputSize * 3 * numBytesPerChannel); + d.imgData.order(ByteOrder.nativeOrder()); d.intValues = new int[d.inputSize * d.inputSize]; + + d.tfLite.setNumThreads(NUM_THREADS); d.outputLocations = new float[1][NUM_RESULTS][4]; d.outputClasses = new float[1][NUM_RESULTS][NUM_CLASSES]; return d; @@ -198,10 +211,11 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { for (int i = 0; i < inputSize; ++i) { for (int j = 0; j < inputSize; ++j) { - int pixel = intValues[j * inputSize + i]; - img[0][j][i][2] = (float) (pixel & 0xFF) / 128.0f - 1.0f; - img[0][j][i][1] = (float) ((pixel >> 8) & 0xFF) / 128.0f - 1.0f; - img[0][j][i][0] = (float) ((pixel >> 16) & 0xFF) / 128.0f - 1.0f; + int pixelValue = intValues[i * inputSize + j]; + // Float model + imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD); + imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD); + imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD); } } Trace.endSection(); // preprocessBitmap @@ -211,7 +225,7 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { outputLocations = new float[1][NUM_RESULTS][4]; outputClasses = new float[1][NUM_RESULTS][NUM_CLASSES]; - Object[] inputArray = {img}; + Object[] inputArray = {imgData}; Map outputMap = new HashMap<>(); outputMap.put(0, outputLocations); outputMap.put(1, outputClasses); diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/AssetUtils.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/AssetUtils.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/AssetUtils.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/AssetUtils.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/BorderedText.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/BorderedText.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/BorderedText.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/BorderedText.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/ImageUtils.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/ImageUtils.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/ImageUtils.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/ImageUtils.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/Logger.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Logger.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/Logger.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Logger.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/Size.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Size.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/Size.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Size.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/SplitTimer.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/SplitTimer.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/SplitTimer.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/SplitTimer.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/MultiBoxTracker.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/MultiBoxTracker.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/ObjectTracker.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/ObjectTracker.java diff --git a/tensorflow/contrib/lite/examples/android/res/animator/color_animation.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/animator/color_animation.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/animator/color_animation.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/animator/color_animation.xml diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_action_info.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_launcher.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/tile.9.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/tile.9.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-hdpi/tile.9.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/tile.9.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_action_info.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_launcher.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_action_info.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_launcher.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_action_info.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_launcher.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable/border.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable/border.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable/border.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable/border.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/activity_camera.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_camera.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/activity_camera.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_camera.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/activity_speech.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_speech.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/activity_speech.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_speech.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment_stylize.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_stylize.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment_stylize.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_stylize.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment_tracking.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_tracking.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment_tracking.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_tracking.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/list_text_item.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/list_text_item.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/list_text_item.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/list_text_item.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-sw600dp/template-dimens.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-dimens.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-sw600dp/template-dimens.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-dimens.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-sw600dp/template-styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-sw600dp/template-styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-v11/styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-v11/styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-v11/template-styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-v11/template-styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/template-styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-v14/styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v14/styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-v14/styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-v14/styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-v21/base-colors.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-colors.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-v21/base-colors.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-colors.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-v21/base-template-styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-v21/base-template-styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-template-styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/attrs.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/attrs.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/attrs.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/attrs.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/base-strings.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/base-strings.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/base-strings.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/base-strings.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/colors.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/colors.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/colors.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/colors.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/strings.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/strings.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/strings.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/strings.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/template-dimens.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-dimens.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/template-dimens.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-dimens.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/template-styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/template-styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-styles.xml diff --git a/tensorflow/contrib/lite/examples/android/build.gradle b/tensorflow/contrib/lite/examples/android/build.gradle index 0d4de358156a5d139e35cc542b8d36ab24e763b9..a47fa4bbf6730c7d1269737564381c8464224713 100644 --- a/tensorflow/contrib/lite/examples/android/build.gradle +++ b/tensorflow/contrib/lite/examples/android/build.gradle @@ -1,52 +1,23 @@ -apply plugin: 'com.android.application' +// Top-level build file where you can add configuration options common to all sub-projects/modules. -android { - compileSdkVersion 26 - buildToolsVersion "26.0.1" - defaultConfig { - applicationId "org.tensorflow.lite.demo" - minSdkVersion 15 - targetSdkVersion 26 - versionCode 1 - versionName "1.0" - testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" - - // Remove this block. - jackOptions { - enabled true - } - } - lintOptions { - abortOnError false - } - buildTypes { - release { - minifyEnabled false - proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' - } - } - aaptOptions { - noCompress "tflite" +buildscript { + repositories { + jcenter() } + dependencies { + classpath 'com.android.tools.build:gradle:3.0.1' - compileOptions { - sourceCompatibility JavaVersion.VERSION_1_8 - targetCompatibility JavaVersion.VERSION_1_8 + // NOTE: Do not place your application dependencies here; they belong + // in the individual module build.gradle files } } -repositories { - maven { - url 'https://google.bintray.com/tensorflow' +allprojects { + repositories { + jcenter() } } -dependencies { - compile fileTree(dir: 'libs', include: ['*.jar']) - androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { - exclude group: 'com.android.support', module: 'support-annotations' - }) - compile 'org.tensorflow:tensorflow-lite:+' - - testCompile 'junit:junit:4.12' +task clean(type: Delete) { + delete rootProject.buildDir } diff --git a/tensorflow/contrib/lite/examples/android/settings.gradle b/tensorflow/contrib/lite/examples/android/settings.gradle new file mode 100644 index 0000000000000000000000000000000000000000..e7b4def49cb53d9aa04228dd3edb14c9e635e003 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/settings.gradle @@ -0,0 +1 @@ +include ':app' diff --git a/tensorflow/contrib/lite/examples/minimal/BUILD b/tensorflow/contrib/lite/examples/minimal/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..b403628d6c457ce3fb67eac3675fd7bb9187deab --- /dev/null +++ b/tensorflow/contrib/lite/examples/minimal/BUILD @@ -0,0 +1,27 @@ +# Description: +# TensorFlow Lite minimal example. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts") + +tf_cc_binary( + name = "minimal", + srcs = [ + "minimal.cc", + ], + linkopts = tflite_linkopts() + select({ + "//tensorflow:android": [ + "-pie", # Android 5.0 and later supports only PIE + "-lm", # some builtin ops, e.g., tanh, need -lm + ], + "//conditions:default": [], + }), + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ], +) diff --git a/tensorflow/contrib/lite/examples/minimal/minimal.cc b/tensorflow/contrib/lite/examples/minimal/minimal.cc index 8b0ace96ccaf06ac1cbdc2ea95ac6e92ef886993..8b65cde7b79fde19280ad778ea874c64b01d169a 100644 --- a/tensorflow/contrib/lite/examples/minimal/minimal.cc +++ b/tensorflow/contrib/lite/examples/minimal/minimal.cc @@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/model.h" +#include #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" -#include +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/optional_debug_tools.h" // This is an example that is minimal to read a model // from disk and perform inference. There is no data being loaded @@ -29,14 +30,13 @@ limitations under the License. using namespace tflite; -#define TFLITE_MINIMAL_CHECK(x) \ - if(!(x)) { \ - fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \ - exit(1); \ +#define TFLITE_MINIMAL_CHECK(x) \ + if (!(x)) { \ + fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \ + exit(1); \ } - -int main(int argc, char *argv[]) { +int main(int argc, char* argv[]) { if(argc != 2) { fprintf(stderr, "minimal \n"); return 1; @@ -44,8 +44,8 @@ int main(int argc, char *argv[]) { const char* filename = argv[1]; // Load model - std::unique_ptr model - = tflite::FlatBufferModel::BuildFromFile(filename); + std::unique_ptr model = + tflite::FlatBufferModel::BuildFromFile(filename); TFLITE_MINIMAL_CHECK(model != nullptr); // Build the interpreter @@ -57,12 +57,16 @@ int main(int argc, char *argv[]) { // Allocate tensor buffers. TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk); + printf("=== Pre-invoke Interpreter State ===\n"); + tflite::PrintInterpreterState(interpreter.get()); // Fill input buffers // TODO(user): Insert code to fill input tensors // Run inference TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk); + printf("\n\n=== Post-invoke Interpreter State ===\n"); + tflite::PrintInterpreterState(interpreter.get()); // Read output buffers // TODO(user): Insert getting data out code. diff --git a/tensorflow/contrib/lite/g3doc/apis.md b/tensorflow/contrib/lite/g3doc/apis.md index 50cc146a87ee9ab94aea6a92fb2fb5c531f83369..a591a353dd8f0ac94ecaa3f12e1aa1c57566ef69 100644 --- a/tensorflow/contrib/lite/g3doc/apis.md +++ b/tensorflow/contrib/lite/g3doc/apis.md @@ -7,6 +7,9 @@ no surprise that the APIs try to avoid unnecessary copies at the expense of convenience. Similarly, consistency with TensorFlow APIs was not an explicit goal and some variance is to be expected. +There is also a Python API for TensorFlow Lite described +[here](../toco/g3doc/python_api.md#interpreter). + ## C++ In order to run the inference model in TensorFlow Lite, one has to load the diff --git a/tensorflow/contrib/lite/g3doc/benchmarks.md b/tensorflow/contrib/lite/g3doc/benchmarks.md new file mode 100644 index 0000000000000000000000000000000000000000..29b087bea7aab1fcbc87ef764795f01e87b0bf9e --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/benchmarks.md @@ -0,0 +1,178 @@ +# Performance Benchmark numbers + +This document contains the performance benchmark numbers for running a few well +known models on some Android and iOS devices. + +The benchmark numbers were generated by running the [TFLite benchmark +binary](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark) +on Android and running the [iOS benchmark +app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios) +on iOS. + +# Android benchmarks + +When running Android benchmarks, the CPU affinity is set to use big cores on the +device to reduce variance (see +[details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#reducing-variance-between-runs-on-android)). + +Models are assumed to have been downloaded from the link, unzipped and pushed to +`/data/local/tmp/tflite_models` folder. The benchmark binary is built according +to instructions listed +[here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#on-android). +and is assumed to have been pushed to `/data/local/tmp`. + +The following command was used to run the benchmark: + +``` +adb shell taskset ${CPU_MASK} /data/local/tmp/benchmark_model \ + --num_threads=1 \ + --graph=/data/local/tmp/tflite_models/${GRAPH} \ + --warmup_runs=1 \ + --num_runs=50 \ + --use_nnapi=false +``` + +where `${GRAPH}` is the name of model and `${CPU_MASK}` is the CPU affinity +chosen according to the following table: + +Device | CPU_MASK | +-------| ---------- +Pixel 2 | f0 | +Pixel xl | 0c | + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Model NameDevice Mean inference time (std dev)
+ Mobilenet_1.0_224(float) + Pixel 2 166.5 ms (2.6 ms)
Pixel xl 122.9 ms (1.8 ms)
+ Mobilenet_1.0_224 (quant) + Pixel 2 69.5 ms (0.9 ms)
Pixel xl 78.9 ms (2.2 ms)
+ NASNet mobile + Pixel 2 273.8 ms (3.5 ms)
Pixel xl 210.8 ms (4.2 ms)
+ SqueezeNet + Pixel 2 234.0 ms (2.1 ms)
Pixel xl 158.0 ms (2.1 ms)
+ Inception_ResNet_V2 + Pixel 2 2846.0 ms (15.0 ms)
Pixel xl 1973.0 ms (15.0 ms)
+ Inception_V4 + Pixel 2 3180.0 ms (11.7 ms)
Pixel xl 2262.0 ms (21.0 ms)
+ +# iOS benchmarks + +For running iOS benchmarks, the [benchmark +app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios) +was modified to include the appropriate model and `benchmark_params.json` was +modified to set `num_threads` to 1. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Model NameDevice Mean inference time (std dev)
+ Mobilenet_1.0_224(float) + iPhone 8 32.2 ms (0.8 ms)
+ Mobilenet_1.0_224 (quant) + iPhone 8 24.4 ms (0.8 ms)
+ NASNet mobile + iPhone 8 60.3 ms (0.6 ms)
+ SqueezeNet + iPhone 8 44.3 (0.7 ms)
+ Inception_ResNet_V2 + iPhone 8562.4 ms (18.2 ms)
+ Inception_V4 + iPhone 8 661.0 ms (29.2 ms)
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index 965273f0f04d33b52903c0551fff3533c31d3bd8..dcd17bbeabda08eaf86f8d5ac7f26cea0d3719a3 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -584,6 +584,31 @@ Options { } ``` +**RSQRT** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: result of computing element-wise reciprocal square root of the input tensor +} +``` + +**SHAPE** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: a 1D tensor representing the shape of the input tensor +} +Options { + out_type: the output type of the op (int32 or int64). Defaults to int32. +} +``` + **SLICE** ``` @@ -670,6 +695,17 @@ Options { } ``` +**SQRT** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: result of computing element-wise square root of the input tensor +} +``` + **SQUEEZE** ``` @@ -742,6 +778,18 @@ Outputs { } ``` +**POW** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: elementwise pow of the input tensors +} +``` + And these are TensorFlow Lite operations that are present but not ready for custom models yet: diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 3287f9c4fdeeb8949e6fa15f4ec8c0aca2dd8a08..62a0b1ff0817d25bc8d4caaedf96d27c141b85ef 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -359,10 +359,13 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims, case kTfLiteBool: *bytes = sizeof(bool) * count; break; + case kTfLiteComplex64: + *bytes = sizeof(std::complex) * count; + break; default: ReportError(&context_, - "Only float32, int16, int32, int64, uint8, bool supported " - "currently."); + "Only float32, int16, int32, int64, uint8, bool, complex64 " + "supported currently."); return kTfLiteError; } return kTfLiteOk; @@ -605,9 +608,17 @@ TfLiteStatus Interpreter::Invoke() { } EnsureTensorsVectorCapacity(); + tensor_resized_since_op_invoke_ = false; if (OpInvoke(registration, &node) == kTfLiteError) { status = kTfLiteError; } + + // Force execution prep for downstream ops if the latest op triggered the + // resize of a dynamic tensor. + if (tensor_resized_since_op_invoke_ && + HasDynamicTensor(context_, node.outputs)) { + next_execution_plan_index_to_prepare_ = execution_plan_index + 1; + } } if (!allow_buffer_handle_output_) { @@ -783,6 +794,8 @@ TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor, if (tensor->allocation_type == kTfLiteArenaRw || tensor->allocation_type == kTfLiteDynamic || tensor->allocation_type == kTfLiteArenaRwPersistent) { + tensor_resized_since_op_invoke_ |= + TfLiteIntArrayEqual(tensor->dims, new_size) == 0; if (tensor->type != kTfLiteString) { size_t bytesRequired; TfLiteStatus status = BytesRequired(tensor->type, new_size->data, diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 37961cd1dc97607510edc9e6f0141c8bfc431c0d..033b8ee5fabc416fd5936b7ff69697235cd9e7e7 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -17,6 +17,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ #define TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ +#include #include #include #include @@ -39,6 +40,10 @@ constexpr TfLiteType typeToTfLiteType() { return kTfLiteInt32; } template <> +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteInt16; +} +template <> constexpr TfLiteType typeToTfLiteType() { return kTfLiteInt64; } @@ -54,6 +59,10 @@ template <> constexpr TfLiteType typeToTfLiteType() { return kTfLiteBool; } +template <> +constexpr TfLiteType typeToTfLiteType>() { + return kTfLiteComplex64; +} // Forward declare since NNAPIDelegate uses Interpreter. class NNAPIDelegate; @@ -393,6 +402,13 @@ class Interpreter { // WARNING: This is an experimental API and subject to change. TfLiteStatus ResetVariableTensorsToZero(); + // Retrieve an operator's description of its work, for profiling purposes. + const char* OpProfilingString(const TfLiteRegistration& op_reg, + const TfLiteNode* node) const { + if (op_reg.profiling_string == nullptr) return nullptr; + return op_reg.profiling_string(&context_, node); + } + private: // Give 'op_reg' a chance to initialize itself using the contents of // 'buffer'. @@ -589,6 +605,11 @@ class Interpreter { bool allow_buffer_handle_output_ = false; + // Tracking bit for whether a tensor was resized in the course of an op + // invocation. This is a useful hint to ensure that dynamic tensor outputs + // trigger downstream reallocation after op invocation. + bool tensor_resized_since_op_invoke_ = false; + // Profiler for this interpreter instance. profiling::Profiler* profiler_; }; diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index b977cb089c39e3904d1d9f83551fc401e82663d8..21cdf87d1e421868d1b62c5e23c2481cfbb4c989 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -23,6 +23,12 @@ limitations under the License. #include "tensorflow/contrib/lite/testing/util.h" namespace tflite { +namespace ops { +namespace builtin { +TfLiteRegistration* Register_PADV2(); +TfLiteRegistration* Register_NEG(); +} // namespace builtin +} // namespace ops namespace { // Make an interpreter that has no tensors and no nodes @@ -615,6 +621,59 @@ TEST(BasicInterpreter, TestUnsupportedDelegateFunctions) { EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteError); } +TEST(BasicInterpreter, DynamicTensorsResizeDescendants) { + // Assemble a graph with a node that has dynamically sized output (via the + // pad op), followed by a node with a standard element-wise op (negate). + Interpreter interpreter; + interpreter.AddTensors(4); + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({3}); + TfLiteQuantizationParams quant; + interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {2, 2, 1, 1}, + quant); + interpreter.SetTensorParametersReadWrite(1, kTfLiteInt32, "", {4, 2}, quant); + interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {}, quant); + interpreter.SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {}, quant); + + TfLiteRegistration* pad_op = tflite::ops::builtin::Register_PADV2(); + TfLiteRegistration* neg_op = tflite::ops::builtin::Register_NEG(); + interpreter.AddNodeWithParameters({0, 1}, {2}, nullptr, 0, nullptr, pad_op); + interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr, neg_op); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + // Configure [[2,2],[4,4]] padding and execute the graph. + interpreter.typed_tensor(1)[0] = 2; + interpreter.typed_tensor(1)[1] = 2; + interpreter.typed_tensor(1)[2] = 2; + interpreter.typed_tensor(1)[3] = 2; + interpreter.typed_tensor(1)[4] = 0; + interpreter.typed_tensor(1)[5] = 0; + interpreter.typed_tensor(1)[6] = 0; + interpreter.typed_tensor(1)[7] = 0; + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + + // Both the output and intermediate tensor sizes should reflect the output + // from the dynamic pad operation. + ASSERT_EQ(interpreter.tensor(2)->bytes, sizeof(float) * 6 * 6); + ASSERT_EQ(interpreter.tensor(3)->bytes, sizeof(float) * 6 * 6); + + // Now configure [[4,4],[6,6]] padding and execute the graph. + interpreter.typed_tensor(1)[0] = 4; + interpreter.typed_tensor(1)[1] = 4; + interpreter.typed_tensor(1)[2] = 6; + interpreter.typed_tensor(1)[3] = 6; + interpreter.typed_tensor(1)[4] = 0; + interpreter.typed_tensor(1)[5] = 0; + interpreter.typed_tensor(1)[6] = 0; + interpreter.typed_tensor(1)[7] = 0; + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + + // Again, the output and intermediate tensor sizes should reflect the *new* + // resize from the latest pad operation. + ASSERT_EQ(interpreter.tensor(2)->bytes, sizeof(float) * 10 * 14); + ASSERT_EQ(interpreter.tensor(3)->bytes, sizeof(float) * 10 * 14); +} + TEST(InterpreterTensorsCapacityTest, TestWithinHeadroom) { Interpreter interpreter; ASSERT_EQ(interpreter.AddTensors(Interpreter::kTensorsReservedCapacity), diff --git a/tensorflow/contrib/lite/java/aar_with_jni.bzl b/tensorflow/contrib/lite/java/aar_with_jni.bzl index 4450bc9085555b3416f51bac07ea94a1240e919c..db837cf29edfc0ffe9950ffedc02cca1389b0fdf 100644 --- a/tensorflow/contrib/lite/java/aar_with_jni.bzl +++ b/tensorflow/contrib/lite/java/aar_with_jni.bzl @@ -1,5 +1,7 @@ """Generate zipped aar file including different variants of .so in jni folder.""" +load("@build_bazel_rules_android//android:rules.bzl", "android_binary") + def aar_with_jni(name, android_library): # Generate dummy AndroidManifest.xml for dummy apk usage # (dummy apk is generated by _dummy_app_for_so target below) @@ -19,7 +21,7 @@ EOF # Generate dummy apk including .so files and later we extract out # .so files and throw away the apk. - native.android_binary( + android_binary( name = name + "_dummy_app_for_so", manifest = name + "_generated_AndroidManifest.xml", custom_package = "dummy.package.for.so", diff --git a/tensorflow/contrib/lite/java/demo/app/build.gradle b/tensorflow/contrib/lite/java/demo/app/build.gradle index 44ea2dcd908644bcfc637f71573ce722adaf6935..192162cfce787ffbf13e2b0db2da972116407888 100644 --- a/tensorflow/contrib/lite/java/demo/app/build.gradle +++ b/tensorflow/contrib/lite/java/demo/app/build.gradle @@ -5,7 +5,8 @@ android { buildToolsVersion "26.0.1" defaultConfig { applicationId "android.example.com.tflitecamerademo" - minSdkVersion 15 + // Required by Camera2 API. + minSdkVersion 21 targetSdkVersion 26 versionCode 1 versionName "1.0" diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD index d6fbef9cc938993b283103984307ab51e609dd6e..220d6c2159b56f6349e93132418fa0f6c69d1ab3 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD +++ b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD @@ -1,3 +1,5 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_binary") + package(default_visibility = ["//visibility:private"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD index 362d93636f72205ddcda6d97fa9fae376ff211f1..f232b00045cf1df6a31ada80af4cc5885a4c0099 100644 --- a/tensorflow/contrib/lite/java/ovic/BUILD +++ b/tensorflow/contrib/lite/java/ovic/BUILD @@ -1,6 +1,8 @@ # Description: # OVIC Benchmarker Java API. +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD index 83974f4b337baedebaf9c9ffc0a03501418a3e36..a8d751ade26adc358e130138381eab9956f2d848 100644 --- a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD +++ b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD @@ -1,3 +1,5 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_binary") + # Sample app for OVIC benchmarking. licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index 2ae6c516b03ef4292667bbd944c73d2eeaf82db3..80de88b6a1cd75b033e116f76f5612ee66e48f03 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -311,8 +311,30 @@ final class NativeInterpreterWrapper implements AutoCloseable { return DataType.fromNumber(type).toStringName(); } + /** + * Gets the quantization zero point of an output. + * + * @throws IllegalArgumentExeption if the output index is invalid. + */ + int getOutputQuantizationZeroPoint(int index) { + return getOutputQuantizationZeroPoint(interpreterHandle, index); + } + + /** + * Gets the quantization scale of an output. + * + * @throws IllegalArgumentExeption if the output index is invalid. + */ + float getOutputQuantizationScale(int index) { + return getOutputQuantizationScale(interpreterHandle, index); + } + private static native int getOutputDataType(long interpreterHandle, int outputIdx); + private static native int getOutputQuantizationZeroPoint(long interpreterHandle, int outputIdx); + + private static native float getOutputQuantizationScale(long interpreterHandle, int outputIdx); + private static final int ERROR_BUFFER_SIZE = 512; private long errorHandle; diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index 1fb6997fb9ba180e9a3f3a89a6d177086440c0d7..31f7b58fbc30cab9e6cb813094ea4b2627ba5cba 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -561,6 +561,38 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType( return static_cast(type); } +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationZeroPoint( + JNIEnv* env, jclass clazz, jlong handle, jint output_idx) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return 0; + const int idx = static_cast(output_idx); + if (output_idx < 0 || output_idx >= interpreter->outputs().size()) { + throwException(env, kIllegalArgumentException, + "Failed to get %d-th output out of %d outputs", output_idx, + interpreter->outputs().size()); + return 0; + } + TfLiteTensor* target = interpreter->tensor(interpreter->outputs()[idx]); + return static_cast(target->params.zero_point); +} + +JNIEXPORT jfloat JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationScale( + JNIEnv* env, jclass clazz, jlong handle, jint output_idx) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return 1.0f; + const int idx = static_cast(output_idx); + if (output_idx < 0 || output_idx >= interpreter->outputs().size()) { + throwException(env, kIllegalArgumentException, + "Failed to get %d-th output out of %d outputs", output_idx, + interpreter->outputs().size()); + return 1.0f; + } + TfLiteTensor* target = interpreter->tensor(interpreter->outputs()[idx]); + return static_cast(target->params.scale); +} + JNIEXPORT jboolean JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput( JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h index eaa765cb343e9764bd0ef018d636a76f4b8a13e4..128ece49811a112684dac7b36810e920eeeb7351 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h @@ -152,6 +152,28 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType( JNIEnv* env, jclass clazz, jlong handle, jint output_idx); +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JI)I + * + * Gets output quantization zero point. + */ +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationZeroPoint( + JNIEnv* env, jclass clazz, jlong handle, jint output_idx); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JI)F + * + * Gets output quantization scale. + */ +JNIEXPORT jfloat JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationScale( + JNIEnv* env, jclass clazz, jlong handle, jint output_idx); + /* * Class: org_tensorflow_lite_NativeInterpreterWrapper * Method: diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java index 7c00d3196fd001a288d77d4e01f0b30978d72afe..9e41cb132d8386748e24c46d846e04f158d8b4c6 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -41,6 +41,9 @@ public final class NativeInterpreterWrapperTest { private static final String BYTE_MODEL_PATH = "tensorflow/contrib/lite/java/src/testdata/uint8.bin"; + private static final String QUANTIZED_MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/quantized.bin"; + private static final String INVALID_MODEL_PATH = "tensorflow/contrib/lite/java/src/testdata/invalid_model.bin"; @@ -536,4 +539,16 @@ public final class NativeInterpreterWrapperTest { assertThat(wrapper.getOutputDataType(0)).contains("byte"); wrapper.close(); } + + @Test + public void testGetOutputQuantizationParams() { + try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH)) { + assertThat(wrapper.getOutputQuantizationZeroPoint(0)).isEqualTo(0); + assertThat(wrapper.getOutputQuantizationScale(0)).isWithin(1e-6f).of(0.0f); + } + try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(QUANTIZED_MODEL_PATH)) { + assertThat(wrapper.getOutputQuantizationZeroPoint(0)).isEqualTo(127); + assertThat(wrapper.getOutputQuantizationScale(0)).isWithin(1e-6f).of(0.25f); + } + } } diff --git a/tensorflow/contrib/lite/java/src/testdata/quantized.bin b/tensorflow/contrib/lite/java/src/testdata/quantized.bin new file mode 100644 index 0000000000000000000000000000000000000000..4062088cdf717e8752490de5c9acff35fd6af54f Binary files /dev/null and b/tensorflow/contrib/lite/java/src/testdata/quantized.bin differ diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD index b524246d436858bbf506809a38cead2897f78d93..af1d99ef41e6413d8ef2c6f478aaa8f9e3931ff8 100644 --- a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD +++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD @@ -1,6 +1,8 @@ # Description: # Internal helper function to test TF Lite API. +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index cf5d0b4ce9cb3c516c185f31fea12db70a2c3bdb..61d5af3478474f006fe50cbbc9d2749127086c51 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -142,6 +142,7 @@ cc_library( "conv.cc", "depthwise_conv.cc", "dequantize.cc", + "detection_postprocess.cc", "div.cc", "elementwise.cc", "embedding_lookup.cc", @@ -157,16 +158,18 @@ cc_library( "lsh_projection.cc", "lstm.cc", "maximum_minimum.cc", - "mean.cc", "mfcc.cc", "mul.cc", "neg.cc", "pad.cc", "pooling.cc", + "pow.cc", + "reduce.cc", "register.cc", "reshape.cc", "resize_bilinear.cc", "select.cc", + "shape.cc", "skip_gram.cc", "slice.cc", "space_to_batch_nd.cc", @@ -246,6 +249,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "detection_postprocess_test", + size = "small", + srcs = ["detection_postprocess_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) + tf_cc_test( name = "activations_test", size = "small", @@ -554,9 +571,9 @@ tf_cc_test( ) tf_cc_test( - name = "mean_test", + name = "reduce_test", size = "small", - srcs = ["mean_test.cc"], + srcs = ["reduce_test.cc"], tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", @@ -979,6 +996,34 @@ tf_cc_test( ], ) +tf_cc_test( + name = "shape_test", + size = "small", + srcs = ["shape_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "pow_test", + size = "small", + srcs = ["pow_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc index add36b46c0b8a4deab1e842d50194c8b99a3a20c..99f81c4a8a78ab0b2a24955d77f25ed09da13b84 100644 --- a/tensorflow/contrib/lite/kernels/activations.cc +++ b/tensorflow/contrib/lite/kernels/activations.cc @@ -84,6 +84,38 @@ TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) { &data->input_left_shift); data->input_range_radius = CalculateInputRadius(kInputIntegerBits, data->input_left_shift); + } else if (input->type == kTfLiteInt16) { + static constexpr int kInputIntegerBits = 3; + static constexpr int kOutputFractionalBits = 15; + + // These operators are implemented in fixed-point arithmetic, + // which intrinsically wants symmetric ranges (zero_point==0) + // and power-of-two scales (power-of-two is abbreviated below as POT). + // While more general support would be possible by means of rescaling, + // that would add some overhead and some loss of accuracy and wouldn't + // be used at the moment as current quantized LSTM applications are + // happy with symmetric, power-of-two-scales quantization. So we just + // implement that narrow case only for now. + + TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + + int input_scale_log2_rounded; + TF_LITE_ENSURE(context, + CheckedLog2(input->params.scale, &input_scale_log2_rounded)); + + int output_scale_log2_rounded; + TF_LITE_ENSURE( + context, CheckedLog2(output->params.scale, &output_scale_log2_rounded)); + TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded, + -kOutputFractionalBits); + + data->input_left_shift = + (15 - kInputIntegerBits) + input_scale_log2_rounded; + // Support for shifts is limited until we have a parameterized version of + // SaturatingRoundingMultiplyByPOT(). + TF_LITE_ENSURE(context, data->input_left_shift >= 0); + TF_LITE_ENSURE(context, data->input_left_shift <= 1); } return context->ResizeTensor(context, output, @@ -114,6 +146,30 @@ TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) { &data->input_left_shift); data->input_range_radius = CalculateInputRadius(kInputIntegerBits, data->input_left_shift); + } else if (input->type == kTfLiteInt16) { + static constexpr int kInputIntegerBits = 3; + static constexpr int kOutputFractionalBits = 15; + + // See comments in TanhPrepare about requiring zero_point==0 + // and a power-of-two ("POT") scale. + + TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + + int input_scale_log2_rounded; + TF_LITE_ENSURE(context, + CheckedLog2(input->params.scale, &input_scale_log2_rounded)); + + int output_scale_log2_rounded; + TF_LITE_ENSURE( + context, CheckedLog2(output->params.scale, &output_scale_log2_rounded)); + TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded, + -kOutputFractionalBits); + + data->input_left_shift = + (15 - kInputIntegerBits) + input_scale_log2_rounded; + // The int16 logistic implementation does not support shifting of the input. + TF_LITE_ENSURE_EQ(context, data->input_left_shift, 0); } return context->ResizeTensor(context, output, @@ -250,12 +306,19 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { for (; in < in_end; in++, out++) *out = std::tanh(*in); return kTfLiteOk; } break; + case kTfLiteInt16: { + optimized_ops::Tanh(GetTensorData(input), GetTensorShape(input), + data->input_left_shift, + GetTensorData(output), + GetTensorShape(output)); + return kTfLiteOk; + } break; case kTfLiteUInt8: { - optimized_ops::Tanh(GetTensorData(input), GetTensorDims(input), + optimized_ops::Tanh(GetTensorData(input), GetTensorShape(input), input->params.zero_point, data->input_range_radius, data->input_multiplier, data->input_left_shift, GetTensorData(output), - GetTensorDims(output)); + GetTensorShape(output)); return kTfLiteOk; } break; default: @@ -280,12 +343,18 @@ TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) { for (; in < in_end; in++, out++) *out = 1.f / (1.f + std::exp(-*in)); break; } + case kTfLiteInt16: { + optimized_ops::Logistic( + GetTensorData(input), GetTensorShape(input), + GetTensorData(output), GetTensorShape(output)); + break; + } case kTfLiteUInt8: { optimized_ops::Logistic( - GetTensorData(input), GetTensorDims(input), + GetTensorData(input), GetTensorShape(input), input->params.zero_point, data->input_range_radius, data->input_multiplier, data->input_left_shift, - GetTensorData(output), GetTensorDims(output)); + GetTensorData(output), GetTensorShape(output)); break; } default: @@ -341,26 +410,26 @@ void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output, const int batch_size = input->dims->data[0]; const int input_size = input->dims->data[1]; optimized_ops::Softmax(GetTensorData(input), - GetTensorDims({batch_size, 1, 1, input_size}), + GetTensorShape({batch_size, 1, 1, input_size}), data->input_multiplier, data->input_left_shift, data->diff_min, GetTensorData(output), - GetTensorDims({batch_size, 1, 1, input_size})); + GetTensorShape({batch_size, 1, 1, input_size})); } // Takes a 4D tensor and perform softmax along the forth dimension. void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output, TfLiteSoftmaxParams* params) { - optimized_ops::Softmax(GetTensorData(input), GetTensorDims(input), + optimized_ops::Softmax(GetTensorData(input), GetTensorShape(input), params->beta, GetTensorData(output), - GetTensorDims(output)); + GetTensorShape(output)); } void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output, TfLiteSoftmaxParams* params, OpData* data) { - optimized_ops::Softmax(GetTensorData(input), GetTensorDims(input), + optimized_ops::Softmax(GetTensorData(input), GetTensorShape(input), data->input_multiplier, data->input_left_shift, data->diff_min, GetTensorData(output), - GetTensorDims(output)); + GetTensorShape(output)); } TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { @@ -415,8 +484,8 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) { switch (input->type) { case kTfLiteFloat32: optimized_ops::LogSoftmax( - GetTensorData(input), GetTensorDims(input), - GetTensorData(output), GetTensorDims(output)); + GetTensorData(input), GetTensorShape(input), + GetTensorData(output), GetTensorShape(output)); return kTfLiteOk; default: context->ReportError(context, "Only float32 supported currently., got %d", diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc index 50a84edd475c8051a563cf8ed9fc03099829b786..587e1303da6afed1fc711100f457f1bf62b0b7e1 100644 --- a/tensorflow/contrib/lite/kernels/activations_test.cc +++ b/tensorflow/contrib/lite/kernels/activations_test.cc @@ -75,23 +75,42 @@ class FloatActivationsOpModel : public BaseActivationsOpModel { std::vector GetOutput() { return ExtractVector(output_); } }; -// TODO(ahentz): I don't quite understand the tradeoffs in the quantized -// implementation of sigmoid and software, but a tolerance of twice the output -// scale seems reasonable. We might want to change this if we have a better -// theoretical bound. +// Our fixed-point math function implementations have roughly 12 bits of +// accuracy, when specialized to 16-bit fixed-point arithmetic. +// That is purely an implementation compromise, it would have been possible +// to get closer to 16 bits of accuracy but that would be more expensive, +// and not needed for our purposes as ultimately the output is either +// immediately down-quantized to 8 bits, or will typically be at the output +// of the surrounding LSTM cell. +// So we can require roughly 2^-12 accuracy when the output is 16-bit, and +// we can more or less expect the full 2^-8 accuracy when the output is 8-bit. +// +// However, the representable output interval is often [-1, 1] (it has to be +// for tanh, and even for logistic, when we implement it in fixed-point, we +// typically have to do so on such a symmetric interval, e.g. ARM NEON only +// has signed fixed-point arithmetic (SQRDMULH)). As the width of [-1, 1] +// is 2, our representable values are often diluted by a factor of 2, whence +// the factor of 2 below. const float kQuantizedTolerance = 2 * (1. / 256); +const float kQuantizedToleranceInt16 = 2 * (1. / 4096); class QuantizedActivationsOpModel : public BaseActivationsOpModel { public: using BaseActivationsOpModel::BaseActivationsOpModel; + template void SetInput(std::initializer_list data) { - QuantizeAndPopulate(input_, data); + QuantizeAndPopulate(input_, data); } - std::vector GetOutput() { return ExtractVector(output_); } + template + + std::vector GetOutput() { + return ExtractVector(output_); + } + template std::vector GetDequantizedOutput() { - return Dequantize(ExtractVector(output_), - GetScale(output_), GetZeroPoint(output_)); + return Dequantize(ExtractVector(output_), GetScale(output_), + GetZeroPoint(output_)); } }; @@ -152,24 +171,47 @@ TEST(FloatActivationsOpTest, Tanh) { } TEST(QuantizedActivationsOpTest, Tanh) { + const float kMin = -1; + const float kMax = 127.f / 128.f; QuantizedActivationsOpModel m( BuiltinOperator_TANH, - /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -8, 8}, - /*output=*/{TensorType_UINT8, {1, 2, 4, 1}, -1, 1}); - m.SetInput({ + /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, 8 * kMin, 8 * kMax}, + /*output=*/{TensorType_UINT8, {1, 2, 4, 1}, kMin, kMax}); + m.SetInput({ 0, -6, 2, 4, // -4, -2, 8, 1, // }); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( { 0.0, -0.999987, 0.964027, 0.999329, // - -0.996078, -0.96402, 0.99999, 0.76159, // + -0.999329, -0.96402, 0.99999, 0.76159, // }, - 4 * (1. / 256)))); - EXPECT_THAT(m.GetOutput(), - ElementsAreArray({128, 0, 251, 255, 0, 5, 255, 226})); + kQuantizedTolerance))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({128, 0, 251, 255, 0, 5, 255, 225})); +} + +TEST(QuantizedActivationsOpTest, TanhInt16) { + const float kMin = -1; + const float kMax = 32767.f / 32768.f; + QuantizedActivationsOpModel m( + BuiltinOperator_TANH, + /*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax}, + /*output=*/{TensorType_INT16, {1, 2, 4, 1}, kMin, kMax}); + m.SetInput({ + 0, -6, 2, 4, // + -4, -2, 8, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.0, -0.999987, 0.964027, 0.999329, // + -0.999329, -0.96402, 0.99999, 0.76159, // + }, + kQuantizedToleranceInt16))); } TEST(FloatActivationsOpTest, Sigmoid) { @@ -190,22 +232,43 @@ TEST(QuantizedActivationsOpTest, Sigmoid) { QuantizedActivationsOpModel m( BuiltinOperator_LOGISTIC, /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -10, 10}); - m.SetInput({ + m.SetInput({ 0, -6, 2, 4, // 3, -2, 10, 1, // }); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( { 0.5, 0.002473, 0.880797, 0.982014, // 0.952574, 0.119203, 0.999955, 0.731059, // }, kQuantizedTolerance))); - EXPECT_THAT(m.GetOutput(), + EXPECT_THAT(m.GetOutput(), ElementsAreArray({128, 1, 227, 251, 244, 32, 255, 188})); } +TEST(QuantizedActivationsOpTest, SigmoidInt16) { + const float kMin = -1; + const float kMax = 32767.f / 32768.f; + QuantizedActivationsOpModel m( + BuiltinOperator_LOGISTIC, + /*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax}, + /*output=*/{TensorType_INT16, {1, 2, 4, 1}, kMin, kMax}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.5, 0.002473, 0.880797, 0.982014, // + 0.952574, 0.119203, 0.999955, 0.731059, // + }, + kQuantizedToleranceInt16))); +} + TEST(FloatActivationsOpTest, Softmax4D) { FloatActivationsOpModel m(0.1, /*input=*/{TensorType_FLOAT32, {1, 2, 1, 4}}); @@ -241,12 +304,12 @@ TEST(QuantizedActivationsOpTest, Softmax4D) { QuantizedActivationsOpModel m( 0.1, /*input=*/{TensorType_UINT8, {1, 2, 1, 4}, -10, 10}); - m.SetInput({ + m.SetInput({ 0, -6, 2, 4, // depth = 0 3, -2, 10, 1, // depth = 1 }); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( { .23463, .12877, .28658, .35003, // @@ -258,21 +321,22 @@ TEST(QuantizedActivationsOpTest, Softmax4D) { QuantizedActivationsOpModel m2( 0.1, /*input=*/{TensorType_UINT8, {4, 1, 1, 2}, -10, 10}); - m2.SetInput({ + m2.SetInput({ 0, -6, // 2, 4, // 3, -2, // 10, 1, // }); m2.Invoke(); - EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( - { - 0.645656, 0.354344, // - 0.450166, 0.549834, // - 0.622459, 0.377541, // - 0.710949, 0.28905, // - }, - kQuantizedTolerance))); + EXPECT_THAT(m2.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); } TEST(FloatActivationsOpTest, Softmax2D) { @@ -309,12 +373,12 @@ TEST(FloatActivationsOpTest, Softmax2D) { TEST(QuantizedActivationsOpTest, Softmax2D) { QuantizedActivationsOpModel m(0.1, /*input=*/{TensorType_UINT8, {2, 4}, -10, 10}); - m.SetInput({ + m.SetInput({ 0, -6, 2, 4, // 3, -2, 10, 1, // }); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( { .23463, .12877, .28658, .35003, // @@ -325,21 +389,22 @@ TEST(QuantizedActivationsOpTest, Softmax2D) { // Same input, but a different shape. QuantizedActivationsOpModel m2(0.1, /*input=*/{TensorType_UINT8, {4, 2}, -10, 10}); - m2.SetInput({ + m2.SetInput({ 0, -6, // 2, 4, // 3, -2, // 10, 1, // }); m2.Invoke(); - EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( - { - 0.645656, 0.354344, // - 0.450166, 0.549834, // - 0.622459, 0.377541, // - 0.710949, 0.28905, // - }, - kQuantizedTolerance))); + EXPECT_THAT(m2.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); } // This contains the same test values as the Softmax test, but reference answer diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc index 443ce8924a43669fb264e19561c733d7e3436cb0..f44d531cbfa9ed41f881380752558555aab97b4d 100644 --- a/tensorflow/contrib/lite/kernels/add.cc +++ b/tensorflow/contrib/lite/kernels/add.cc @@ -39,6 +39,23 @@ constexpr int kOutputTensor = 0; struct OpData { bool requires_broadcast; + + // These fields are used in both the general 8-bit -> 8bit quantized path, + // and the special 16-bit -> 16bit quantized path + int input1_shift; + int input2_shift; + int32 output_activation_min; + int32 output_activation_max; + + // These fields are used only in the general 8-bit -> 8bit quantized path + int32 input1_multiplier; + int32 input2_multiplier; + int32 output_multiplier; + int output_shift; + int left_shift; + int32 input1_offset; + int32 input2_offset; + int32 output_offset; }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -52,6 +69,7 @@ void Free(TfLiteContext* context, void* buffer) { } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); OpData* data = reinterpret_cast(node->user_data); TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); @@ -74,92 +92,169 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output_size = TfLiteIntArrayCopy(input1->dims); } + if (output->type == kTfLiteUInt8) { + // 8bit -> 8bit general quantized path, with general rescalings + data->input1_offset = -input1->params.zero_point; + data->input2_offset = -input2->params.zero_point; + data->output_offset = output->params.zero_point; + data->left_shift = 20; + const double twice_max_input_scale = + 2 * std::max(input1->params.scale, input2->params.scale); + const double real_input1_multiplier = + input1->params.scale / twice_max_input_scale; + const double real_input2_multiplier = + input2->params.scale / twice_max_input_scale; + const double real_output_multiplier = + twice_max_input_scale / + ((1 << data->left_shift) * output->params.scale); + + QuantizeMultiplierSmallerThanOneExp( + real_input1_multiplier, &data->input1_multiplier, &data->input1_shift); + data->input1_shift *= -1; + + QuantizeMultiplierSmallerThanOneExp( + real_input2_multiplier, &data->input2_multiplier, &data->input2_shift); + data->input2_shift *= -1; + + QuantizeMultiplierSmallerThanOneExp( + real_output_multiplier, &data->output_multiplier, &data->output_shift); + data->output_shift *= -1; + + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + + } else if (output->type == kTfLiteInt16) { + // 16bit -> 16bit special quantized path, supporting only a rather + // narrow case of quantization parameters: zero_points must all be 0 + // ("symmetric quantization") and scales must be power-of-two (which + // we abbreviate as "POT" below). The intended use case for this path + // is in LSTM cells, where, due to the constraints of implementing + // some of the math in these LSTM cells in fixed-point arithmetic, + // we need to have such symmetric, power-of-two quantization + // (Fixed-point formats are inherently symmetric, power-of-two). + TF_LITE_ENSURE_EQ(context, input1->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, input2->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + + int input1_scale_log2_rounded; + bool input1_scale_is_pot = + CheckedLog2(input1->params.scale, &input1_scale_log2_rounded); + TF_LITE_ENSURE(context, input1_scale_is_pot); + + int input2_scale_log2_rounded; + bool input2_scale_is_pot = + CheckedLog2(input2->params.scale, &input2_scale_log2_rounded); + TF_LITE_ENSURE(context, input2_scale_is_pot); + + int output_scale_log2_rounded; + bool output_scale_is_pot = + CheckedLog2(output->params.scale, &output_scale_log2_rounded); + TF_LITE_ENSURE(context, output_scale_is_pot); + + data->input1_shift = output_scale_log2_rounded - input1_scale_log2_rounded; + data->input2_shift = output_scale_log2_rounded - input2_scale_log2_rounded; + + // Shifting of one input is supported. The graph quantization should ensure + // that the other input matches the output. + TF_LITE_ENSURE(context, data->input1_shift == 0 || data->input2_shift == 0); + TF_LITE_ENSURE(context, data->input1_shift >= 0); + TF_LITE_ENSURE(context, data->input2_shift >= 0); + + CalculateActivationRangeQuantized(context, params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + return context->ResizeTensor(context, output, output_size); } template -void EvalAddFloat(TfLiteContext* context, TfLiteNode* node, - TfLiteAddParams* params, const OpData* data, - const TfLiteTensor* input1, const TfLiteTensor* input2, - TfLiteTensor* output) { - float output_activation_min, output_activation_max; - CalculateActivationRangeFloat(params->activation, &output_activation_min, - &output_activation_max); -#define TF_LITE_ADD(type, opname) \ - type::opname(GetTensorData(input1), GetTensorDims(input1), \ - GetTensorData(input2), GetTensorDims(input2), \ - output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)) - if (kernel_type == kReference) { - if (data->requires_broadcast) { - TF_LITE_ADD(reference_ops, BroadcastAdd); +void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params, + const OpData* data, const TfLiteTensor* input1, + const TfLiteTensor* input2, TfLiteTensor* output) { +#define TF_LITE_ADD(type, opname, data_type) \ + data_type output_activation_min, output_activation_max; \ + CalculateActivationRange(params->activation, &output_activation_min, \ + &output_activation_max); \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (output->type == kTfLiteInt32) { + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_ADD(reference_ops, BroadcastAdd, int32_t); + } else { + TF_LITE_ADD(reference_ops, Add, int32_t); + } } else { - TF_LITE_ADD(reference_ops, Add); + if (data->requires_broadcast) { + TF_LITE_ADD(optimized_ops, BroadcastAdd, int32_t); + } else { + TF_LITE_ADD(optimized_ops, Add, int32_t); + } } - } else { - if (data->requires_broadcast) { - TF_LITE_ADD(optimized_ops, BroadcastAdd); + } else if (output->type == kTfLiteFloat32) { + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_ADD(reference_ops, BroadcastAdd, float); + } else { + TF_LITE_ADD(reference_ops, Add, float); + } } else { - TF_LITE_ADD(optimized_ops, Add); + if (data->requires_broadcast) { + TF_LITE_ADD(optimized_ops, BroadcastAdd, float); + } else { + TF_LITE_ADD(optimized_ops, Add, float); + } } } #undef TF_LITE_ADD } template -void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, - TfLiteAddParams* params, const OpData* data, - const TfLiteTensor* input1, const TfLiteTensor* input2, - TfLiteTensor* output) { - auto input1_offset = -input1->params.zero_point; - auto input2_offset = -input2->params.zero_point; - auto output_offset = output->params.zero_point; - const int left_shift = 20; - const double twice_max_input_scale = - 2 * std::max(input1->params.scale, input2->params.scale); - const double real_input1_multiplier = - input1->params.scale / twice_max_input_scale; - const double real_input2_multiplier = - input2->params.scale / twice_max_input_scale; - const double real_output_multiplier = - twice_max_input_scale / ((1 << left_shift) * output->params.scale); - - int32 input1_multiplier; - int input1_shift; - QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, - &input1_multiplier, &input1_shift); - input1_shift *= -1; - int32 input2_multiplier; - int input2_shift; - QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, - &input2_multiplier, &input2_shift); - input2_shift *= -1; - int32 output_multiplier; - int output_shift; - QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, - &output_multiplier, &output_shift); - output_shift *= -1; - - int32 output_activation_min, output_activation_max; - CalculateActivationRangeUint8(params->activation, output, - &output_activation_min, &output_activation_max); - -#define TF_LITE_ADD(type, opname) \ - type::opname(left_shift, GetTensorData(input1), \ - GetTensorDims(input1), input1_offset, input1_multiplier, \ - input1_shift, GetTensorData(input2), \ - GetTensorDims(input2), input2_offset, input2_multiplier, \ - input2_shift, output_offset, output_multiplier, output_shift, \ - output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)); - // The quantized version of Add doesn't support activations, so we - // always use BroadcastAdd. - if (kernel_type == kReference) { - TF_LITE_ADD(reference_ops, BroadcastAdd); - } else { - TF_LITE_ADD(optimized_ops, BroadcastAdd); - } +TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteAddParams* params, const OpData* data, + const TfLiteTensor* input1, + const TfLiteTensor* input2, + TfLiteTensor* output) { + if (output->type == kTfLiteUInt8) { +#define TF_LITE_ADD(type, opname) \ + type::opname( \ + data->left_shift, GetTensorData(input1), GetTensorDims(input1), \ + data->input1_offset, data->input1_multiplier, data->input1_shift, \ + GetTensorData(input2), GetTensorDims(input2), \ + data->input2_offset, data->input2_multiplier, data->input2_shift, \ + data->output_offset, data->output_multiplier, data->output_shift, \ + data->output_activation_min, data->output_activation_max, \ + GetTensorData(output), GetTensorDims(output)); + // The quantized version of Add doesn't support activations, so we + // always use BroadcastAdd. + if (kernel_type == kReference) { + TF_LITE_ADD(reference_ops, BroadcastAdd); + } else { + TF_LITE_ADD(optimized_ops, BroadcastAdd); + } +#undef TF_LITE_ADD + } else if (output->type == kTfLiteInt16) { +#define TF_LITE_ADD(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + data->input1_shift, GetTensorData(input2), \ + GetTensorDims(input2), data->input2_shift, \ + data->output_activation_min, data->output_activation_max, \ + GetTensorData(output), GetTensorDims(output)); + // The quantized version of Add doesn't support activations, so we + // always use BroadcastAdd. + if (kernel_type == kReference) { + TF_LITE_ADD(reference_ops, Add); + } else { + TF_LITE_ADD(optimized_ops, Add); + } #undef TF_LITE_ADD + } + + return kTfLiteOk; } template @@ -171,15 +266,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - if (output->type == kTfLiteFloat32) { - EvalAddFloat(context, node, params, data, input1, input2, - output); - } else if (output->type == kTfLiteUInt8) { - EvalAddQuantized(context, node, params, data, input1, input2, - output); + if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) { + EvalAdd(context, node, params, data, input1, input2, output); + } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) { + TF_LITE_ENSURE_OK(context, + EvalAddQuantized(context, node, params, data, + input1, input2, output)); } else { context->ReportError(context, - "Inputs and outputs not all float|uint8 types."); + "Inputs and outputs not all float|uint8|int16 types."); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/add_test.cc b/tensorflow/contrib/lite/kernels/add_test.cc index 956d05bed5162f6ce59705d59aad77ff056dda77..0b5844321133de103919de76d367574f018a6698 100644 --- a/tensorflow/contrib/lite/kernels/add_test.cc +++ b/tensorflow/contrib/lite/kernels/add_test.cc @@ -52,6 +52,13 @@ class FloatAddOpModel : public BaseAddOpModel { std::vector GetOutput() { return ExtractVector(output_); } }; +class IntegerAddOpModel : public BaseAddOpModel { + public: + using BaseAddOpModel::BaseAddOpModel; + + std::vector GetOutput() { return ExtractVector(output_); } +}; + class QuantizedAddOpModel : public BaseAddOpModel { public: using BaseAddOpModel::BaseAddOpModel; @@ -60,15 +67,26 @@ class QuantizedAddOpModel : public BaseAddOpModel { return Dequantize(ExtractVector(output_), GetScale(output_), GetZeroPoint(output_)); } + + std::vector GetDequantizedOutputInt16() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } }; // for quantized Add, the error shouldn't exceed 2*step -float GetTolerance(int min, int max) { +float GetTolerance(float min, float max) { float kQuantizedStep = (max - min) / 255.0; float kQuantizedTolerance = 2.0 * kQuantizedStep; return kQuantizedTolerance; } +float GetToleranceInt16(float min, float max) { + float kQuantizedStep = (max - min) / 32767.f; + float kQuantizedTolerance = 2.0 * kQuantizedStep; + return kQuantizedTolerance; +} + TEST(FloatAddOpModel, NoActivation) { FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}}, @@ -122,6 +140,57 @@ TEST(FloatAddOpModel, WithBroadcast) { } } +TEST(IntegerAddOpModel, NoActivation) { + IntegerAddOpModel m({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}}, + ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8}); + m.PopulateTensor(m.input2(), {1, 2, 3, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-19, 4, 10, 13})); +} + +TEST(IntegerAddOpModel, ActivationRELU_N1_TO_1) { + IntegerAddOpModel m({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}}, + ActivationFunctionType_RELU_N1_TO_1); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8}); + m.PopulateTensor(m.input2(), {1, 2, 3, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1, 1, 1, 1})); +} + +TEST(IntegerAddOpModel, VariousInputShapes) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + IntegerAddOpModel m({TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8, 11, 20}); + m.PopulateTensor(m.input2(), {1, 2, 3, 5, 11, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-19, 04, 10, 13, 22, 21})) + << "With shape number " << i; + } +} + +TEST(IntegerAddOpModel, WithBroadcast) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + IntegerAddOpModel m({TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, {}}, // always a scalar + {TensorType_INT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8, 11, 20}); + m.PopulateTensor(m.input2(), {1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-19, 3, 8, 9, 12, 21}))) + << "With shape number " << i; + } +} + TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); std::vector> inputs1 = { @@ -144,6 +213,31 @@ TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) { } } +TEST(QuantizedAddOpModel, QuantizedTestsNoActivationInt16) { + const float kMin = -1.f; + const float kMax = 32767.f / 32768.f; + float kQuantizedTolerance = GetToleranceInt16(kMin, kMax); + std::vector> inputs1 = { + {0.1, 0.2, 0.3, 0.4}, {-0.8, 0.2, 0.4, 0.7}, {-0.8, 0.2, 0.7, 0.3}}; + std::vector> inputs2 = { + {0.6, 0.4, 0.3, 0.1}, {0.6, 0.4, 0.5, -0.8}, {0.6, 0.4, -0.8, 0.5}}; + std::vector> results = { + {0.7, 0.6, 0.6, 0.5}, {-0.2, 0.6, 0.9, -0.1}, {-0.2, 0.6, -0.1, 0.8}}; + for (int i = 0; i < inputs1.size(); ++i) { + QuantizedAddOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_INT16, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_INT16, {}, kMin, kMax}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), inputs1[i]); + m.QuantizeAndPopulate(m.input2(), inputs2[i]); + m.Invoke(); + EXPECT_THAT( + m.GetDequantizedOutputInt16(), + ElementsAreArray(ArrayFloatNear(results[i], kQuantizedTolerance))) + << "With test number " << i; + } +} + TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU_N1_TO_1) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); std::vector> inputs1 = {{-0.8, 0.2, 0.9, 0.7}, diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/contrib/lite/kernels/cast.cc index 60770ca0aa8b85d9710d26beca3d4d603da5db2f..8dd48af57fd1bd9ef21256410d6bede6b7baa566 100644 --- a/tensorflow/contrib/lite/kernels/cast.cc +++ b/tensorflow/contrib/lite/kernels/cast.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include #include +#include #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" @@ -53,6 +54,20 @@ void copyCast(const FromT* in, ToT* out, int num_elements) { [](FromT a) { return static_cast(a); }); } +template +void copyCast(const std::complex* in, ToT* out, int num_elements) { + std::transform(in, in + num_elements, out, [](std::complex a) { + return static_cast(std::real(a)); + }); +} + +template <> +void copyCast(const std::complex* in, std::complex* out, + int num_elements) { + std::transform(in, in + num_elements, out, + [](std::complex a) { return a; }); +} + template TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out, int num_elements) { @@ -72,6 +87,10 @@ TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out, case kTfLiteBool: copyCast(in, out->data.b, num_elements); break; + case kTfLiteComplex64: + copyCast(in, reinterpret_cast*>(out->data.c64), + num_elements); + break; default: // Unsupported type. return kTfLiteError; @@ -95,6 +114,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return copyToTensor(input->data.f, output, num_elements); case kTfLiteBool: return copyToTensor(input->data.b, output, num_elements); + case kTfLiteComplex64: + return copyToTensor( + reinterpret_cast*>(input->data.c64), output, + num_elements); default: // Unsupported type. return kTfLiteError; diff --git a/tensorflow/contrib/lite/kernels/cast_test.cc b/tensorflow/contrib/lite/kernels/cast_test.cc index 53e20007378392467356ab29ecb8b217bb7a9e89..954f998206563a38c74a1382092851cfbee1013b 100644 --- a/tensorflow/contrib/lite/kernels/cast_test.cc +++ b/tensorflow/contrib/lite/kernels/cast_test.cc @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" @@ -73,6 +75,71 @@ TEST(CastOpModel, CastBoolToFloat) { ElementsAreArray({1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f})); } +TEST(CastOpModel, CastComplex64ToFloat) { + CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_FLOAT32, {2, 3}}); + m.PopulateTensor>( + m.input(), + {std::complex(1.0f, 11.0f), std::complex(2.0f, 12.0f), + std::complex(3.0f, 13.0f), std::complex(4.0f, 14.0f), + std::complex(5.0f, 15.0f), std::complex(6.0f, 16.0f)}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f})); +} + +TEST(CastOpModel, CastFloatToComplex64) { + CastOpModel m({TensorType_FLOAT32, {2, 3}}, {TensorType_COMPLEX64, {2, 3}}); + m.PopulateTensor(m.input(), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + m.Invoke(); + EXPECT_THAT( + m.ExtractVector>(m.output()), + ElementsAreArray( + {std::complex(1.0f, 0.0f), std::complex(2.0f, 0.0f), + std::complex(3.0f, 0.0f), std::complex(4.0f, 0.0f), + std::complex(5.0f, 0.0f), std::complex(6.0f, 0.0f)})); +} + +TEST(CastOpModel, CastComplex64ToInt) { + CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_INT32, {2, 3}}); + m.PopulateTensor>( + m.input(), + {std::complex(1.0f, 11.0f), std::complex(2.0f, 12.0f), + std::complex(3.0f, 13.0f), std::complex(4.0f, 14.0f), + std::complex(5.0f, 15.0f), std::complex(6.0f, 16.0f)}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(CastOpModel, CastIntToComplex64) { + CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_COMPLEX64, {2, 3}}); + m.PopulateTensor(m.input(), {1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT( + m.ExtractVector>(m.output()), + ElementsAreArray( + {std::complex(1.0f, 0.0f), std::complex(2.0f, 0.0f), + std::complex(3.0f, 0.0f), std::complex(4.0f, 0.0f), + std::complex(5.0f, 0.0f), std::complex(6.0f, 0.0f)})); +} + +TEST(CastOpModel, CastComplex64ToComplex64) { + CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_COMPLEX64, {2, 3}}); + m.PopulateTensor>( + m.input(), + {std::complex(1.0f, 11.0f), std::complex(2.0f, 12.0f), + std::complex(3.0f, 13.0f), std::complex(4.0f, 14.0f), + std::complex(5.0f, 15.0f), std::complex(6.0f, 16.0f)}); + m.Invoke(); + EXPECT_THAT( + m.ExtractVector>(m.output()), + ElementsAreArray( + {std::complex(1.0f, 11.0f), std::complex(2.0f, 12.0f), + std::complex(3.0f, 13.0f), std::complex(4.0f, 14.0f), + std::complex(5.0f, 15.0f), + std::complex(6.0f, 16.0f)})); +} + } // namespace } // namespace tflite int main(int argc, char** argv) { diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index 14b399ef96eab1d5066a22a7eb95ab061e8ba2bc..0321b2e2a0088bdb09b2c3c61827be8064fe939b 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -179,9 +179,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired(context, node)); - bool hasBias = node->inputs->size == 3; + bool has_bias = node->inputs->size == 3; // Check number of inputs/outputs - TF_LITE_ENSURE(context, hasBias || node->inputs->size == 2); + TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2); TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; TfLiteTensor* input = &context->tensors[node->inputs->data[0]]; @@ -204,9 +204,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // TODO(ahentz): At this point the optimized versions require 'bias'. We can // either change that or document that convolution requires it. - TF_LITE_ENSURE(context, hasBias); + TF_LITE_ENSURE(context, has_bias); - if (hasBias) { + if (has_bias) { bias = &context->tensors[node->inputs->data[2]]; if (data_type == kTfLiteUInt8) { TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32); @@ -226,29 +226,30 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Matching GetWindowedOutputSize in TensorFlow. auto padding = params->padding; - auto computeOutSize = [padding](int imageSize, int filterSize, int stride, - int dilationRate) -> int { - int effectiveFilterSize = (filterSize - 1) * dilationRate + 1; + auto compute_out_size = [padding](int image_size, int filter_size, int stride, + int dilation_rate) -> int { + int effective_filter_size = (filter_size - 1) * dilation_rate + 1; return padding == kTfLitePaddingSame - ? (imageSize + stride - 1) / stride + ? (image_size + stride - 1) / stride : padding == kTfLitePaddingValid - ? (imageSize - effectiveFilterSize + stride) / stride + ? (image_size - effective_filter_size + stride) / stride : 0; }; - int outWidth = computeOutSize(width, filter_width, params->stride_width, - params->dilation_width_factor); - int outHeight = computeOutSize(height, filter_height, params->stride_height, - params->dilation_height_factor); + int out_width = compute_out_size(width, filter_width, params->stride_width, + params->dilation_width_factor); + int out_height = + compute_out_size(height, filter_height, params->stride_height, + params->dilation_height_factor); data->padding.height = ComputePadding(params->stride_height, params->dilation_height_factor, - height, filter_height, outHeight); + height, filter_height, out_height); data->padding.width = ComputePadding(params->stride_width, params->dilation_width_factor, width, - filter_width, outWidth); + filter_width, out_width); - TF_LITE_ENSURE(context, hasBias); + TF_LITE_ENSURE(context, has_bias); // Note that quantized inference requires that all tensors have their // parameters set. This is usually done during quantized training. @@ -267,8 +268,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); output_size->data[0] = batches; - output_size->data[1] = outHeight; - output_size->data[2] = outWidth; + output_size->data[1] = out_height; + output_size->data[2] = out_width; output_size->data[3] = channels_out; auto output_status = context->ResizeTensor(context, output, output_size); @@ -308,18 +309,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* hwcn_weights = &context->tensors[node->temporaries->data[data->hwcn_weights_index]]; hwcn_weights->type = data_type; - hwcn_weights->allocation_type = kTfLiteDynamic; - // Make sure we release any previous allocations before we reallocate. - // TODO(petewarden): Persistent arenas would be a better fit for this, but - // they aren't fully implemented yet. - if (hwcn_weights->data.raw) { - free(hwcn_weights->data.raw); - hwcn_weights->data.raw = nullptr; - } + hwcn_weights->allocation_type = kTfLiteArenaRwPersistent; - // Note that hwcn_weights_status is a kTfLiteDynamic tensor, and - // ResizeTensor will actually allocate space for it. The would be more - // efficient if we placed hwcn_weights_status in the persistent arena. auto hwcn_weights_status = context->ResizeTensor(context, hwcn_weights, hwcn_weights_size); if (hwcn_weights_status != kTfLiteOk) return hwcn_weights_status; @@ -381,8 +372,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* filter, TfLiteTensor* bias, TfLiteTensor* im2col, TfLiteTensor* hwcn_weights, TfLiteTensor* output) { float output_activation_min, output_activation_max; - CalculateActivationRangeFloat(params->activation, &output_activation_min, - &output_activation_max); + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); KernelType effective_kernel_type; if (((kernel_type == kMultithreadOptimized) || (kernel_type == kCblasOptimized)) && @@ -458,9 +449,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; TfLiteTensor* input = &context->tensors[node->inputs->data[0]]; TfLiteTensor* filter = &context->tensors[node->inputs->data[1]]; - bool hasBias = node->inputs->size == 3; + bool has_bias = node->inputs->size == 3; TfLiteTensor* bias = - hasBias ? &context->tensors[node->inputs->data[2]] : nullptr; + has_bias ? &context->tensors[node->inputs->data[2]] : nullptr; TfLiteTensor* im2col = data->need_im2col ? &context->tensors[node->temporaries->data[data->im2col_index]] diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc index a308de055f49eddba99d02e264fad11409a799f4..16e5f1d065d8ea6d187c5e368d6c9385fe62514b 100644 --- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc +++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc @@ -173,8 +173,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input, const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output) { float output_activation_min, output_activation_max; - CalculateActivationRangeFloat(params->activation, &output_activation_min, - &output_activation_max); + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); void (*depthwise_conv)(const float*, const Dims<4>&, const float*, const Dims<4>&, const float*, const Dims<4>&, int, int, diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess.cc b/tensorflow/contrib/lite/kernels/detection_postprocess.cc new file mode 100644 index 0000000000000000000000000000000000000000..0c532cac5a9f59c8b09ff9aefc294e243561f027 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/detection_postprocess.cc @@ -0,0 +1,591 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace detection_postprocess { + +// Input tensors +constexpr int kInputTensorBoxEncodings = 0; +constexpr int kInputTensorClassPredictions = 1; +constexpr int kInputTensorAnchors = 2; + +// Output tensors +constexpr int kOutputTensorDetectionBoxes = 0; +constexpr int kOutputTensorDetectionClasses = 1; +constexpr int kOutputTensorDetectionScores = 2; +constexpr int kOutputTensorNumDetections = 3; + +constexpr size_t kNumCoordBox = 4; +constexpr size_t kBatchSize = 1; + +// Object Detection model produces axis-aligned boxes in two formats: +// BoxCorner represents the upper right (xmin, ymin) and +// lower left corner (xmax, ymax). +// CenterSize represents the center (xcenter, ycenter), height and width. +// BoxCornerEncoding and CenterSizeEncoding are related as follows: +// ycenter = y / y_scale * anchor.h + anchor.y; +// xcenter = x / x_scale * anchor.w + anchor.x; +// half_h = 0.5*exp(h/ h_scale)) * anchor.h; +// half_w = 0.5*exp(w / w_scale)) * anchor.w; +// ymin = ycenter - half_h +// ymax = ycenter + half_h +// xmin = xcenter - half_w +// xmax = xcenter + half_w +struct BoxCornerEncoding { + float ymin; + float xmin; + float ymax; + float xmax; +}; + +struct CenterSizeEncoding { + float y; + float x; + float h; + float w; +}; +// We make sure that the memory allocations are contiguous with static assert. +static_assert(sizeof(BoxCornerEncoding) == sizeof(float) * kNumCoordBox, + "Size of BoxCornerEncoding is 4 float values"); +static_assert(sizeof(CenterSizeEncoding) == sizeof(float) * kNumCoordBox, + "Size of CenterSizeEncoding is 4 float values"); + +struct OpData { + int max_detections; + int max_classes_per_detection; + float non_max_suppression_score_threshold; + float intersection_over_union_threshold; + int num_classes; + CenterSizeEncoding scale_values; + // Indices of Temporary tensors + int decoded_boxes_index; + int scores_index; + int active_candidate_index; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* op_data = new OpData; + const uint8_t* buffer_t = reinterpret_cast(buffer); + const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); + op_data->max_detections = m["max_detections"].AsInt32(); + op_data->max_classes_per_detection = m["max_classes_per_detection"].AsInt32(); + op_data->non_max_suppression_score_threshold = + m["nms_score_threshold"].AsFloat(); + op_data->intersection_over_union_threshold = m["nms_iou_threshold"].AsFloat(); + op_data->num_classes = m["num_classes"].AsInt32(); + op_data->scale_values.y = m["y_scale"].AsFloat(); + op_data->scale_values.x = m["x_scale"].AsFloat(); + op_data->scale_values.h = m["h_scale"].AsFloat(); + op_data->scale_values.w = m["w_scale"].AsFloat(); + context->AddTensors(context, 1, &op_data->decoded_boxes_index); + context->AddTensors(context, 1, &op_data->scores_index); + context->AddTensors(context, 1, &op_data->active_candidate_index); + return op_data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +// TODO(chowdhery): Add to kernel_util.h +TfLiteStatus SetTensorSizes(TfLiteContext* context, TfLiteTensor* tensor, + std::initializer_list values) { + TfLiteIntArray* size = TfLiteIntArrayCreate(values.size()); + int index = 0; + for (int v : values) { + size->data[index] = v; + ++index; + } + return context->ResizeTensor(context, tensor, size); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* op_data = reinterpret_cast(node->user_data); + // Inputs: box_encodings, scores, anchors + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + const TfLiteTensor* input_box_encodings = + GetInput(context, node, kInputTensorBoxEncodings); + const TfLiteTensor* input_class_predictions = + GetInput(context, node, kInputTensorClassPredictions); + const TfLiteTensor* input_anchors = + GetInput(context, node, kInputTensorAnchors); + TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3); + TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3); + TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2); + // number of detected boxes + const int num_detected_boxes = + op_data->max_detections * op_data->max_classes_per_detection; + + // Outputs: detection_boxes, detection_scores, detection_classes, + // num_detections + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4); + // Output Tensor detection_boxes: size is set to (1, num_detected_boxes, 4) + TfLiteTensor* detection_boxes = + GetOutput(context, node, kOutputTensorDetectionBoxes); + detection_boxes->type = kTfLiteFloat32; + SetTensorSizes(context, detection_boxes, + {kBatchSize, num_detected_boxes, kNumCoordBox}); + + // Output Tensor detection_classes: size is set to (1, num_detected_boxes) + TfLiteTensor* detection_classes = + GetOutput(context, node, kOutputTensorDetectionClasses); + detection_classes->type = kTfLiteFloat32; + SetTensorSizes(context, detection_classes, {kBatchSize, num_detected_boxes}); + + // Output Tensor detection_scores: size is set to (1, num_detected_boxes) + TfLiteTensor* detection_scores = + GetOutput(context, node, kOutputTensorDetectionScores); + detection_scores->type = kTfLiteFloat32; + SetTensorSizes(context, detection_scores, {kBatchSize, num_detected_boxes}); + + // Output Tensor num_detections: size is set to 1 + TfLiteTensor* num_detections = + GetOutput(context, node, kOutputTensorNumDetections); + num_detections->type = kTfLiteFloat32; + // TODO (chowdhery): Make it a scalar when available + SetTensorSizes(context, num_detections, {1}); + + // Temporary tensors + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(3); + node->temporaries->data[0] = op_data->decoded_boxes_index; + node->temporaries->data[1] = op_data->scores_index; + node->temporaries->data[2] = op_data->active_candidate_index; + + // decoded_boxes + TfLiteTensor* decoded_boxes = &context->tensors[op_data->decoded_boxes_index]; + decoded_boxes->type = kTfLiteFloat32; + decoded_boxes->allocation_type = kTfLiteArenaRw; + SetTensorSizes(context, decoded_boxes, + {input_box_encodings->dims->data[1], kNumCoordBox}); + + // scores + TfLiteTensor* scores = &context->tensors[op_data->scores_index]; + scores->type = kTfLiteFloat32; + scores->allocation_type = kTfLiteArenaRw; + SetTensorSizes(context, scores, + {input_class_predictions->dims->data[1], + input_class_predictions->dims->data[2]}); + + // active_candidate + TfLiteTensor* active_candidate = + &context->tensors[op_data->active_candidate_index]; + active_candidate->type = kTfLiteUInt8; + active_candidate->allocation_type = kTfLiteArenaRw; + SetTensorSizes(context, active_candidate, + {input_box_encodings->dims->data[1]}); + + return kTfLiteOk; +} + +class Dequantizer { + public: + Dequantizer(int zero_point, float scale) + : zero_point_(zero_point), scale_(scale) {} + float operator()(uint8 x) { + return (static_cast(x) - zero_point_) * scale_; + } + + private: + int zero_point_; + float scale_; +}; + +void DequantizeBoxEncodings(const TfLiteTensor* input_box_encodings, int idx, + float quant_zero_point, float quant_scale, + CenterSizeEncoding* box_centersize) { + const uint8* boxes = + GetTensorData(input_box_encodings) + kNumCoordBox * idx; + Dequantizer dequantize(quant_zero_point, quant_scale); + box_centersize->y = dequantize(boxes[0]); + box_centersize->x = dequantize(boxes[1]); + box_centersize->h = dequantize(boxes[2]); + box_centersize->w = dequantize(boxes[3]); +} + +template +T ReInterpretTensor(const TfLiteTensor* tensor) { + // TODO (chowdhery): check float + const float* tensor_base = tensor->data.f; + return reinterpret_cast(tensor_base); +} + +template +T ReInterpretTensor(TfLiteTensor* tensor) { + // TODO (chowdhery): check float + float* tensor_base = tensor->data.f; + return reinterpret_cast(tensor_base); +} + +TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node, + OpData* op_data) { + // Parse input tensor boxencodings + const TfLiteTensor* input_box_encodings = + GetInput(context, node, kInputTensorBoxEncodings); + TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[0], kBatchSize); + const int num_boxes = input_box_encodings->dims->data[1]; + TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[2], kNumCoordBox); + const TfLiteTensor* input_anchors = + GetInput(context, node, kInputTensorAnchors); + + // Decode the boxes to get (ymin, xmin, ymax, xmax) based on the anchors + CenterSizeEncoding box_centersize; + CenterSizeEncoding scale_values = op_data->scale_values; + CenterSizeEncoding anchor; + for (int idx = 0; idx < num_boxes; ++idx) { + switch (input_box_encodings->type) { + // Quantized + case kTfLiteUInt8: + DequantizeBoxEncodings( + input_box_encodings, idx, + static_cast(input_box_encodings->params.zero_point), + static_cast(input_box_encodings->params.scale), + &box_centersize); + DequantizeBoxEncodings( + input_anchors, idx, + static_cast(input_anchors->params.zero_point), + static_cast(input_anchors->params.scale), &anchor); + break; + // Float + case kTfLiteFloat32: + box_centersize = ReInterpretTensor( + input_box_encodings)[idx]; + anchor = + ReInterpretTensor(input_anchors)[idx]; + break; + default: + // Unsupported type. + return kTfLiteError; + } + + float ycenter = box_centersize.y / scale_values.y * anchor.h + anchor.y; + float xcenter = box_centersize.x / scale_values.x * anchor.w + anchor.x; + float half_h = + 0.5f * static_cast(std::exp(box_centersize.h / scale_values.h)) * + anchor.h; + float half_w = + 0.5f * static_cast(std::exp(box_centersize.w / scale_values.w)) * + anchor.w; + TfLiteTensor* decoded_boxes = + &context->tensors[op_data->decoded_boxes_index]; + auto& box = ReInterpretTensor(decoded_boxes)[idx]; + box.ymin = ycenter - half_h; + box.xmin = xcenter - half_w; + box.ymax = ycenter + half_h; + box.xmax = xcenter + half_w; + } + return kTfLiteOk; +} + +void DecreasingPartialArgSort(const float* values, int num_values, + int num_to_sort, int* indices) { + std::iota(indices, indices + num_values, 0); + std::partial_sort( + indices, indices + num_to_sort, indices + num_values, + [&values](const int i, const int j) { return values[i] > values[j]; }); +} + +void SelectDetectionsAboveScoreThreshold(const std::vector& values, + const float threshold, + std::vector* keep_values, + std::vector* keep_indices) { + for (int i = 0; i < values.size(); i++) { + if (values[i] >= threshold) { + keep_values->emplace_back(values[i]); + keep_indices->emplace_back(i); + } + } +} + +bool ValidateBoxes(const TfLiteTensor* decoded_boxes, const int num_boxes) { + for (int i = 0; i < num_boxes; ++i) { + // ymax>=ymin, xmax>=xmin + auto& box = ReInterpretTensor(decoded_boxes)[i]; + if (box.ymin >= box.ymax || box.xmin >= box.xmax) { + return false; + } + } + return true; +} + +float ComputeIntersectionOverUnion(const TfLiteTensor* decoded_boxes, + const int i, const int j) { + auto& box_i = ReInterpretTensor(decoded_boxes)[i]; + auto& box_j = ReInterpretTensor(decoded_boxes)[j]; + const float area_i = (box_i.ymax - box_i.ymin) * (box_i.xmax - box_i.xmin); + const float area_j = (box_j.ymax - box_j.ymin) * (box_j.xmax - box_j.xmin); + if (area_i <= 0 || area_j <= 0) return 0.0; + const float intersection_ymin = std::max(box_i.ymin, box_j.ymin); + const float intersection_xmin = std::max(box_i.xmin, box_j.xmin); + const float intersection_ymax = std::min(box_i.ymax, box_j.ymax); + const float intersection_xmax = std::min(box_i.xmax, box_j.xmax); + const float intersection_area = + std::max(intersection_ymax - intersection_ymin, 0.0) * + std::max(intersection_xmax - intersection_xmin, 0.0); + return intersection_area / (area_i + area_j - intersection_area); +} + +// NonMaxSuppressionSingleClass() is O(n^2) pairwise comparison between boxes +// It assumes all boxes are good in beginning and sorts based on the scores. +// If lower-scoring box has too much overlap with a higher-scoring box, +// we get rid of the lower-scoring box. +TfLiteStatus NonMaxSuppressionSingleClassHelper( + TfLiteContext* context, TfLiteNode* node, OpData* op_data, + const std::vector& scores, std::vector* selected) { + const TfLiteTensor* input_box_encodings = + GetInput(context, node, kInputTensorBoxEncodings); + const TfLiteTensor* decoded_boxes = + &context->tensors[op_data->decoded_boxes_index]; + const int num_boxes = input_box_encodings->dims->data[1]; + const int max_detections = op_data->max_detections; + const float non_max_suppression_score_threshold = + op_data->non_max_suppression_score_threshold; + const float intersection_over_union_threshold = + op_data->intersection_over_union_threshold; + // Maximum detections should be positive. + TF_LITE_ENSURE(context, (max_detections >= 0)); + // intersection_over_union_threshold should be positive + // and should be less than 1. + TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) && + (intersection_over_union_threshold <= 1.0f)); + // Validate boxes + TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes)); + + // threshold scores + std::vector keep_indices; + // TODO (chowdhery): Remove the dynamic allocation and replace it + // with temporaries, esp for std::vector + std::vector keep_scores; + SelectDetectionsAboveScoreThreshold( + scores, non_max_suppression_score_threshold, &keep_scores, &keep_indices); + + int num_scores_kept = keep_scores.size(); + std::vector sorted_indices; + sorted_indices.resize(num_scores_kept); + DecreasingPartialArgSort(keep_scores.data(), num_scores_kept, num_scores_kept, + sorted_indices.data()); + + const int num_boxes_kept = num_scores_kept; + const int output_size = std::min(num_boxes_kept, max_detections); + selected->clear(); + TfLiteTensor* active_candidate = + &context->tensors[op_data->active_candidate_index]; + TF_LITE_ENSURE(context, (active_candidate->dims->data[0]) == num_boxes); + int num_active_candidate = num_boxes_kept; + uint8_t* active_box_candidate = (active_candidate->data.uint8); + for (int row = 0; row < num_boxes_kept; row++) { + active_box_candidate[row] = 1; + } + + for (int i = 0; i < num_boxes_kept; ++i) { + if (num_active_candidate == 0 || selected->size() >= output_size) break; + if (active_box_candidate[i] == 1) { + selected->push_back(keep_indices[sorted_indices[i]]); + active_box_candidate[i] = 0; + num_active_candidate--; + } else { + continue; + } + for (int j = i + 1; j < num_boxes_kept; ++j) { + if (active_box_candidate[j] == 1) { + float intersection_over_union = ComputeIntersectionOverUnion( + decoded_boxes, keep_indices[sorted_indices[i]], + keep_indices[sorted_indices[j]]); + + if (intersection_over_union > intersection_over_union_threshold) { + active_box_candidate[j] = 0; + num_active_candidate--; + } + } + } + } + return kTfLiteOk; +} + +// This function implements a fast version of Non Maximal Suppression for +// multiple classes where +// 1) we keep the top-k scores for each anchor and +// 2) during NMS, each anchor only uses the highest class score for sorting. +// 3) Compared to standard NMS, the worst runtime of this version is O(N^2) +// instead of O(KN^2) where N is the number of anchors and K the number of +// classes. +TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context, + TfLiteNode* node, + OpData* op_data, + const float* scores) { + const TfLiteTensor* input_box_encodings = + GetInput(context, node, kInputTensorBoxEncodings); + const TfLiteTensor* decoded_boxes = + &context->tensors[op_data->decoded_boxes_index]; + + TfLiteTensor* detection_boxes = + GetOutput(context, node, kOutputTensorDetectionBoxes); + TfLiteTensor* detection_classes = + GetOutput(context, node, kOutputTensorDetectionClasses); + TfLiteTensor* detection_scores = + GetOutput(context, node, kOutputTensorDetectionScores); + TfLiteTensor* num_detections = + GetOutput(context, node, kOutputTensorNumDetections); + + const int num_boxes = input_box_encodings->dims->data[1]; + const int num_classes = op_data->num_classes; + const int max_categories_per_anchor = op_data->max_classes_per_detection; + // The row index offset is 1 if background class is included and 0 otherwise. + const int label_offset = 1; + TF_LITE_ENSURE(context, (label_offset != -1)); + TF_LITE_ENSURE(context, (max_categories_per_anchor > 0)); + const int num_classes_with_background = num_classes + label_offset; + const int num_categories_per_anchor = + std::min(max_categories_per_anchor, num_classes); + std::vector max_scores; + max_scores.resize(num_boxes); + std::vector sorted_class_indices; + sorted_class_indices.resize(num_boxes * num_classes); + for (int row = 0; row < num_boxes; row++) { + const float* box_scores = + scores + row * num_classes_with_background + label_offset; + int* class_indices = sorted_class_indices.data() + row * num_classes; + DecreasingPartialArgSort(box_scores, num_classes, num_categories_per_anchor, + class_indices); + max_scores[row] = box_scores[class_indices[0]]; + } + // Perform non-maximal suppression on max scores + std::vector selected; + NonMaxSuppressionSingleClassHelper(context, node, op_data, max_scores, + &selected); + // Allocate output tensors + int output_box_index = 0; + for (const auto& selected_index : selected) { + const float* box_scores = + scores + selected_index * num_classes_with_background + label_offset; + const int* class_indices = + sorted_class_indices.data() + selected_index * num_classes; + + for (int col = 0; col < num_categories_per_anchor; ++col) { + int box_offset = num_categories_per_anchor * output_box_index + col; + // detection_boxes + ReInterpretTensor(detection_boxes)[box_offset] = + ReInterpretTensor( + decoded_boxes)[selected_index]; + // detection_classes + detection_classes->data.f[box_offset] = class_indices[col]; + // detection_scores + detection_scores->data.f[box_offset] = box_scores[class_indices[col]]; + output_box_index++; + } + } + num_detections->data.f[0] = output_box_index; + return kTfLiteOk; +} + +void DequantizeClassPredictions(const TfLiteTensor* input_class_predictions, + const int num_boxes, + const int num_classes_with_background, + const TfLiteTensor* scores) { + float quant_zero_point = + static_cast(input_class_predictions->params.zero_point); + float quant_scale = static_cast(input_class_predictions->params.scale); + Dequantizer dequantize(quant_zero_point, quant_scale); + const uint8* scores_quant = GetTensorData(input_class_predictions); + for (int idx = 0; idx < num_boxes * num_classes_with_background; ++idx) { + scores->data.f[idx] = dequantize(scores_quant[idx]); + } +} + +TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context, + TfLiteNode* node, OpData* op_data) { + // Get the input tensors + const TfLiteTensor* input_box_encodings = + GetInput(context, node, kInputTensorBoxEncodings); + const TfLiteTensor* input_class_predictions = + GetInput(context, node, kInputTensorClassPredictions); + const int num_boxes = input_box_encodings->dims->data[1]; + const int num_classes = op_data->num_classes; + TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[0], + kBatchSize); + TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[1], num_boxes); + const int num_classes_with_background = + input_class_predictions->dims->data[2]; + + TF_LITE_ENSURE(context, (num_classes_with_background == num_classes + 1)); + + const TfLiteTensor* scores; + switch (input_class_predictions->type) { + case kTfLiteUInt8: { + TfLiteTensor* temporary_scores = &context->tensors[op_data->scores_index]; + DequantizeClassPredictions(input_class_predictions, num_boxes, + num_classes_with_background, temporary_scores); + scores = temporary_scores; + } break; + case kTfLiteFloat32: + scores = input_class_predictions; + break; + default: + // Unsupported type. + return kTfLiteError; + } + NonMaxSuppressionMultiClassFastHelper(context, node, op_data, + GetTensorData(scores)); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + // TODO(chowdhery): Generalize for any batch size + TF_LITE_ENSURE(context, (kBatchSize == 1)); + auto* op_data = reinterpret_cast(node->user_data); + // These two functions correspond to two blocks in the Object Detection model. + // In future, we would like to break the custom op in two blocks, which is + // currently not feasible because we would like to input quantized inputs + // and do all calculations in float. Mixed quantized/float calculations are + // currently not supported in TFLite. + + // This fills in temporary decoded_boxes + // by transforming input_box_encodings and input_anchors from + // CenterSizeEncodings to BoxCornerEncoding + DecodeCenterSizeBoxes(context, node, op_data); + // This fills in the output tensors + // by choosing effective set of decoded boxes + // based on Non Maximal Suppression, i.e. selecting + // highest scoring non-overlapping boxes. + NonMaxSuppressionMultiClass(context, node, op_data); + + return kTfLiteOk; +} +} // namespace detection_postprocess + +TfLiteRegistration* Register_DETECTION_POSTPROCESS() { + static TfLiteRegistration r = {detection_postprocess::Init, + detection_postprocess::Free, + detection_postprocess::Prepare, + detection_postprocess::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4e0f8484a328d7d1668afd096ad3d08204fbb4a1 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc @@ -0,0 +1,235 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace ops { +namespace custom { + +TfLiteRegistration* Register_DETECTION_POSTPROCESS(); + +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +class BaseDetectionPostprocessOpModel : public SingleOpModel { + public: + BaseDetectionPostprocessOpModel(const TensorData& input1, + const TensorData& input2, + const TensorData& input3, + const TensorData& output1, + const TensorData& output2, + const TensorData& output3, + const TensorData& output4) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + input3_ = AddInput(input3); + output1_ = AddOutput(output1); + output2_ = AddOutput(output2); + output3_ = AddOutput(output3); + output4_ = AddOutput(output4); + + flexbuffers::Builder fbb; + fbb.Map([&]() { + fbb.Int("max_detections", 3); + fbb.Int("max_classes_per_detection", 1); + fbb.Float("nms_score_threshold", 0.0); + fbb.Float("nms_iou_threshold", 0.5); + fbb.Int("num_classes", 2); + fbb.Float("y_scale", 10.0); + fbb.Float("x_scale", 10.0); + fbb.Float("h_scale", 5.0); + fbb.Float("w_scale", 5.0); + }); + fbb.Finish(); + SetCustomOp("TFLite_Detection_PostProcess", fbb.GetBuffer(), + Register_DETECTION_POSTPROCESS); + BuildInterpreter({GetShape(input1_), GetShape(input2_), GetShape(input3_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + int input3() { return input3_; } + + template + void SetInput1(std::initializer_list data) { + PopulateTensor(input1_, data); + } + + template + void SetInput2(std::initializer_list data) { + PopulateTensor(input2_, data); + } + + template + void SetInput3(std::initializer_list data) { + PopulateTensor(input3_, data); + } + + template + std::vector GetOutput1() { + return ExtractVector(output1_); + } + + template + std::vector GetOutput2() { + return ExtractVector(output2_); + } + + template + std::vector GetOutput3() { + return ExtractVector(output3_); + } + + template + std::vector GetOutput4() { + return ExtractVector(output4_); + } + + std::vector GetOutputShape1() { return GetTensorShape(output1_); } + std::vector GetOutputShape2() { return GetTensorShape(output2_); } + std::vector GetOutputShape3() { return GetTensorShape(output3_); } + std::vector GetOutputShape4() { return GetTensorShape(output4_); } + + protected: + int input1_; + int input2_; + int input3_; + int output1_; + int output2_; + int output3_; + int output4_; +}; + +TEST(DetectionPostprocessOpTest, FloatTest) { + BaseDetectionPostprocessOpModel m( + {TensorType_FLOAT32, {1, 6, 4}}, {TensorType_FLOAT32, {1, 6, 3}}, + {TensorType_FLOAT32, {6, 4}}, {TensorType_FLOAT32, {}}, + {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}, + {TensorType_FLOAT32, {}}); + + // six boxes in center-size encoding + m.SetInput1({0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}); + // class scores - two classes with background + m.SetInput2({0., .9, .8, 0., .75, .72, 0., .6, .5, 0., .93, .95, 0., + .5, .4, 0., .3, .2}); + // six anchors in center-size encoding + m.SetInput3({0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0, + 0.5, 0.5, 1.0, 1.0, 0.5, 10.5, 1.0, 1.0, + 0.5, 10.5, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0}); + // Same boxes in box-corner encoding: + // { 0.0, 0.0, 1.0, 1.0, + // 0.0, 0.1, 1.0, 1.1, + // 0.0, -0.1, 1.0, 0.9, + // 0.0, 10.0, 1.0, 11.0, + // 0.0, 10.1, 1.0, 11.1, + // 0.0, 100.0, 1.0, 101.0} + m.Invoke(); + // detection_boxes + // in center-size + std::vector output_shape1 = m.GetOutputShape1(); + EXPECT_THAT(output_shape1, ElementsAre(1, 3, 4)); + EXPECT_THAT( + m.GetOutput1(), + ElementsAreArray(ArrayFloatNear( + {0.0, 10.0, 1.0, 11.0, 0.0, 0.0, 1.0, 1.0, 0.0, 100.0, 1.0, 101.0}, + 1e-1))); + // detection_classes + std::vector output_shape2 = m.GetOutputShape2(); + EXPECT_THAT(output_shape2, ElementsAre(1, 3)); + EXPECT_THAT(m.GetOutput2(), + ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-1))); + // detection_scores + std::vector output_shape3 = m.GetOutputShape3(); + EXPECT_THAT(output_shape3, ElementsAre(1, 3)); + EXPECT_THAT(m.GetOutput3(), + ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-1))); + // num_detections + std::vector output_shape4 = m.GetOutputShape4(); + EXPECT_THAT(output_shape4, ElementsAre(1)); + EXPECT_THAT(m.GetOutput4(), + ElementsAreArray(ArrayFloatNear({3.0}, 1e-1))); +} + +TEST(DetectionPostprocessOpTest, QuantizedTest) { + BaseDetectionPostprocessOpModel m( + {TensorType_UINT8, {1, 6, 4}, -1.0, 1.0}, + {TensorType_UINT8, {1, 6, 3}, 0.0, 1.0}, + {TensorType_UINT8, {6, 4}, 0.0, 100.5}, {TensorType_FLOAT32, {}}, + {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}, + {TensorType_FLOAT32, {}}); + // six boxes in center-size encoding + std::vector> inputs1 = { + {0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}; + m.QuantizeAndPopulate(m.input1(), inputs1[0]); + // class scores - two classes with background + std::vector> inputs2 = { + {0., .9, .8, 0., .75, .72, 0., .6, .5, 0., .93, .95, 0., .5, .4, 0., .3, + .2}}; + m.QuantizeAndPopulate(m.input2(), inputs2[0]); + // six anchors in center-size encoding + std::vector> inputs3 = { + {0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0, + 0.5, 10.5, 1.0, 1.0, 0.5, 10.5, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0}}; + m.QuantizeAndPopulate(m.input3(), inputs3[0]); + m.Invoke(); + // detection_boxes + // in center-size + std::vector output_shape1 = m.GetOutputShape1(); + EXPECT_THAT(output_shape1, ElementsAre(1, 3, 4)); + EXPECT_THAT( + m.GetOutput1(), + ElementsAreArray(ArrayFloatNear( + {0.0, 10.0, 1.0, 11.0, 0.0, 0.0, 1.0, 1.0, 0.0, 100.0, 1.0, 101.0}, + 3e-1))); + // detection_classes + std::vector output_shape2 = m.GetOutputShape2(); + EXPECT_THAT(output_shape2, ElementsAre(1, 3)); + EXPECT_THAT(m.GetOutput2(), + ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-1))); + // detection_scores + std::vector output_shape3 = m.GetOutputShape3(); + EXPECT_THAT(output_shape3, ElementsAre(1, 3)); + EXPECT_THAT(m.GetOutput3(), + ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-1))); + // num_detections + std::vector output_shape4 = m.GetOutputShape4(); + EXPECT_THAT(output_shape4, ElementsAre(1)); + EXPECT_THAT(m.GetOutput4(), + ElementsAreArray(ArrayFloatNear({3.0}, 1e-1))); +} +} // namespace +} // namespace custom +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc index d264821e30cf622ff5d3d8ad513add46caa9e7ae..bc5c3783fd63451fd6d600df2d8e93f740c68e95 100644 --- a/tensorflow/contrib/lite/kernels/div.cc +++ b/tensorflow/contrib/lite/kernels/div.cc @@ -83,8 +83,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input1, const TfLiteTensor* input2, TfLiteTensor* output) { float output_activation_min, output_activation_max; - CalculateActivationRangeFloat(params->activation, &output_activation_min, - &output_activation_max); + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); #define TF_LITE_DIV(type, opname) \ type::opname(GetTensorData(input1), GetTensorDims(input1), \ GetTensorData(input2), GetTensorDims(input2), \ diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc index 98c21ce9d390aaa1f3cb5fdb8f31cbffb1b81d6a..59bab3c4ecd20bf938919ca606a5933f3112f233 100644 --- a/tensorflow/contrib/lite/kernels/elementwise.cc +++ b/tensorflow/contrib/lite/kernels/elementwise.cc @@ -64,6 +64,14 @@ TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) { return Eval(context, node, std::log); } +TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, std::sqrt); +} + +TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, [](float f) { return 1.f / std::sqrt(f); }); +} + } // namespace elementwise TfLiteRegistration* Register_SIN() { @@ -78,6 +86,18 @@ TfLiteRegistration* Register_LOG() { return &r; } +TfLiteRegistration* Register_SQRT() { + static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, + elementwise::SqrtEval}; + return &r; +} + +TfLiteRegistration* Register_RSQRT() { + static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, + elementwise::RsqrtEval}; + return &r; +} + } // namespace builtin } // namespace ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/contrib/lite/kernels/elementwise_test.cc index 10e88d5a31868eeb5f65c7ade1f1c73827dea24a..ce4c602ee5c788d67701af3ecd3e023f2b25aae7 100644 --- a/tensorflow/contrib/lite/kernels/elementwise_test.cc +++ b/tensorflow/contrib/lite/kernels/elementwise_test.cc @@ -60,6 +60,24 @@ TEST(ElementWise, Log) { EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); } +TEST(ElementWise, Sqrt) { + ElementWiseOpModel m(BuiltinOperator_SQRT, {1, 1, 4, 1}); + m.PopulateTensor(m.input(), {0, 1, 2, 4}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray(ArrayFloatNear({0, 1, 1.41421, 2}))); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + +TEST(ElementWise, Rsqrt) { + ElementWiseOpModel m(BuiltinOperator_RSQRT, {1, 1, 4, 1}); + m.PopulateTensor(m.input(), {1, 2, 4, 9}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray(ArrayFloatNear({1, 0.7071, 0.5, 0.33333}))); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/expand_dims_test.cc b/tensorflow/contrib/lite/kernels/expand_dims_test.cc index b755e8ce293442813b26ec3177162a3c95af2f89..50dc860e5a83f185abc70a844abdbc974f7bc4e7 100644 --- a/tensorflow/contrib/lite/kernels/expand_dims_test.cc +++ b/tensorflow/contrib/lite/kernels/expand_dims_test.cc @@ -39,7 +39,7 @@ class ExpandDimsOpModel : public SingleOpModel { void SetInputFloat(std::initializer_list data) { PopulateTensor(input_, data); } - void SetAxis(int axis) { PopulateTensor(axis_, {axis}); } + void SetAxis(int axis) { PopulateTensor(axis_, {axis}); } std::vector GetValuesFloat() { return ExtractVector(output_); } std::vector GetOutputShape() { return GetTensorShape(output_); } @@ -51,7 +51,7 @@ class ExpandDimsOpModel : public SingleOpModel { TEST(ExpandDimsOpTest, DifferentAxis) { ExpandDimsOpModel m({2, 2}, TensorType_FLOAT32); - const auto values = {-1.f, 1.f, -2.f, 2.f}; + std::initializer_list values = {-1.f, 1.f, -2.f, 2.f}; m.SetInputFloat(values); m.SetAxis(0); m.Invoke(); diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc index f6fc0f5b6ad12d58c541efc6eae566ab4b8327f4..3b203dd480f95c5dc70a69aafce0bac6ab2cbc06 100644 --- a/tensorflow/contrib/lite/kernels/fully_connected.cc +++ b/tensorflow/contrib/lite/kernels/fully_connected.cc @@ -63,6 +63,7 @@ constexpr int kInputTensor = 0; constexpr int kWeightsTensor = 1; constexpr int kBiasTensor = 2; constexpr int kOutputTensor = 0; +constexpr int kShuffledInputWorkspaceTensor = 1; constexpr int kScratchBufferTensor = 1; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -87,7 +88,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Check we have all the inputs and outputs we need. TF_LITE_ENSURE_EQ(context, node->inputs->size, 3); - TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + // Shuffled formats need a workspace to store the shuffled input activations. + const int expected_outputs_count = + params->weights_format == kTfLiteFullyConnectedWeightsFormatDefault ? 1 + : 2; + TF_LITE_ENSURE_EQ(context, node->outputs->size, expected_outputs_count); const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); @@ -121,9 +126,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { QuantizeMultiplierSmallerThanOneExp( real_multiplier, &data->output_multiplier, &data->output_shift); data->output_shift *= -1; - CalculateActivationRangeUint8(params->activation, output, - &data->output_activation_min, - &data->output_activation_max); + TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized( + context, params->activation, output, &data->output_activation_min, + &data->output_activation_max)); } // If we have to perform on-the-fly quantization (with quantized weights and @@ -278,44 +283,101 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, int32_t input_offset = -input->params.zero_point; int32_t filter_offset = -filter->params.zero_point; int32_t output_offset = output->params.zero_point; -#define TF_LITE_FULLY_CONNECTED(type) \ +#define TF_LITE_FULLY_CONNECTED(type, output_data_type) \ type::FullyConnected( \ GetTensorData(input), GetTensorDims(input), input_offset, \ GetTensorData(filter), GetTensorDims(filter), filter_offset, \ GetTensorData(bias), GetTensorDims(bias), output_offset, \ data->output_multiplier, data->output_shift, \ data->output_activation_min, data->output_activation_max, \ - GetTensorData(output), GetTensorDims(output), gemm_context) + GetTensorData(output), GetTensorDims(output), \ + gemm_context) if (kernel_type == kReference) { - TF_LITE_FULLY_CONNECTED(reference_ops); - } else if (kernel_type == kPie) { - if (input->type == kTfLiteFloat32) { - // Pie currently only supports quantized models and float inputs/outputs. - TfLiteTensor* input_quantized = - &context->tensors[node->temporaries->data[0]]; - return EvalPieQuantized(context, node, params, data, input, filter, bias, - input_quantized, output); - } else { - // TODO(ahentz): we don't have a quantized version of the PIE kernels, so - // we just defer to the MINI ones. - TF_LITE_FULLY_CONNECTED(optimized_ops); + switch (output->type) { + case kTfLiteUInt8: + TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t); + break; + case kTfLiteInt16: + TF_LITE_FULLY_CONNECTED(reference_ops, int16_t); + break; + default: + context->ReportError( + context, + "Quantized FullyConnected expects output data type uint8 or int16"); + return kTfLiteError; } + } else if (kernel_type == kPie && input->type == kTfLiteFloat32) { + // Pie currently only supports quantized models and float inputs/outputs. + TfLiteTensor* input_quantized = + &context->tensors[node->temporaries->data[0]]; + return EvalPieQuantized(context, node, params, data, input, filter, bias, + input_quantized, output); } else { - TF_LITE_FULLY_CONNECTED(optimized_ops); + switch (output->type) { + case kTfLiteUInt8: + TF_LITE_FULLY_CONNECTED(optimized_ops, uint8_t); + break; + case kTfLiteInt16: + TF_LITE_FULLY_CONNECTED(optimized_ops, int16_t); + break; + default: + context->ReportError( + context, + "Quantized FullyConnected expects output data type uint8 or int16"); + return kTfLiteError; + } } #undef TF_LITE_FULLY_CONNECTED return kTfLiteOk; } +template +TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, + OpData* data, const TfLiteTensor* input, + const TfLiteTensor* filter, + const TfLiteTensor* bias, + TfLiteTensor* output, + TfLiteTensor* shuffled_input_workspace) { + gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context); + + // TODO(b/110697972) decide more consistently if / how / where we want + // to perform this kind of runtime data type checks. + if (input->type != kTfLiteUInt8 || filter->type != kTfLiteUInt8 || + bias->type != kTfLiteInt32 || output->type != kTfLiteInt16 || + shuffled_input_workspace->type != kTfLiteUInt8) { + context->ReportError(context, "Unexpected data type"); + return kTfLiteError; + } + +#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \ + type::ShuffledFullyConnected( \ + GetTensorData(input), GetTensorDims(input), \ + GetTensorData(filter), GetTensorDims(filter), \ + GetTensorData(bias), GetTensorDims(bias), \ + data->output_multiplier, data->output_shift, \ + data->output_activation_min, data->output_activation_max, \ + GetTensorData(output), GetTensorDims(output), \ + GetTensorData(shuffled_input_workspace), gemm_context) + if (kernel_type == kReference) { + TF_LITE_SHUFFLED_FULLY_CONNECTED(reference_ops); + } else { + TF_LITE_SHUFFLED_FULLY_CONNECTED(optimized_ops); + } +#undef TF_LITE_SHUFFLED_FULLY_CONNECTED + + return kTfLiteOk; +} + template TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, TfLiteFullyConnectedParams* params, OpData* data, const TfLiteTensor* input, const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output) { float output_activation_min, output_activation_max; - CalculateActivationRangeFloat(params->activation, &output_activation_min, - &output_activation_max); + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); #define TF_LITE_FULLY_CONNECTED(type) \ type::FullyConnected(GetTensorData(input), GetTensorDims(input), \ GetTensorData(filter), GetTensorDims(filter), \ @@ -352,8 +414,22 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return EvalFloat(context, node, params, data, input, filter, bias, output); case kTfLiteUInt8: - return EvalQuantized(context, node, params, data, input, - filter, bias, output); + if (params->weights_format == + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8) { + TfLiteTensor* shuffled_input_workspace = + GetOutput(context, node, kShuffledInputWorkspaceTensor); + return EvalShuffledQuantized(context, node, params, data, + input, filter, bias, output, + shuffled_input_workspace); + } else if (params->weights_format == + kTfLiteFullyConnectedWeightsFormatDefault) { + return EvalQuantized(context, node, params, data, input, + filter, bias, output); + } else { + context->ReportError(context, + "Unhandled fully-connected weights format"); + return kTfLiteError; + } default: context->ReportError(context, "Type %d not currently supported.", filter->type); diff --git a/tensorflow/contrib/lite/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/kernels/fully_connected_test.cc index 05dd028b484c09bdf90a09fab1238f48e8a9ddab..ec949056971ccb5f7a6f93fa9f236a93625ca6ad 100644 --- a/tensorflow/contrib/lite/kernels/fully_connected_test.cc +++ b/tensorflow/contrib/lite/kernels/fully_connected_test.cc @@ -15,6 +15,7 @@ limitations under the License. // Unit test for TFLite FULLY_CONNECTED op. #include +#include #include #include @@ -133,9 +134,12 @@ static float fully_connected_golden_output[] = { class BaseFullyConnectedOpModel : public SingleOpModel { public: // TODO(ahentz): test different activation types too. - BaseFullyConnectedOpModel(TfLiteRegistration* registration, int units, - int batches, const TensorData& input, - const TensorData& output = {TensorType_FLOAT32}) + BaseFullyConnectedOpModel( + TfLiteRegistration* registration, int units, int batches, + const TensorData& input, const TensorData& output = {TensorType_FLOAT32}, + ActivationFunctionType activation_func = ActivationFunctionType_RELU, + FullyConnectedOptionsWeightsFormat weights_format = + FullyConnectedOptionsWeightsFormat_DEFAULT) : batches_(batches), units_(units) { int total_input_size = 1; for (int i = 0; i < input.shape.size(); ++i) { @@ -159,10 +163,13 @@ class BaseFullyConnectedOpModel : public SingleOpModel { } output_ = AddOutput(output); + if (weights_format != FullyConnectedOptionsWeightsFormat_DEFAULT) { + AddOutput({TensorType_UINT8, input.shape}); + } SetBuiltinOp( BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions, - CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU) + CreateFullyConnectedOptions(builder_, activation_func, weights_format) .Union()); resolver_ = absl::make_unique( BuiltinOperator_FULLY_CONNECTED, registration); @@ -188,13 +195,11 @@ class FloatFullyConnectedOpModel : public BaseFullyConnectedOpModel { public: using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel; - void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + void SetBias(const std::vector& f) { PopulateTensor(bias_, f); } - void SetWeights(std::initializer_list f) { - PopulateTensor(weights_, f); - } + void SetWeights(const std::vector& f) { PopulateTensor(weights_, f); } - void SetInput(std::initializer_list data) { + void SetInput(const std::vector& data) { PopulateTensor(input_, data); } void SetInput(int offset, float* begin, float* end) { @@ -208,20 +213,50 @@ class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel { public: using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel; - void SetBias(std::initializer_list data) { + void SetBias(const std::vector& data) { QuantizeAndPopulate(bias_, data); } - void SetWeights(std::initializer_list data) { + void SetWeights(const std::vector& data) { QuantizeAndPopulate(weights_, data); } - void SetInput(std::initializer_list data) { + void ShuffleAndSetWeights(const std::vector& data, int input_depth, + int output_depth) { + std::vector shuffled_data(data.size()); + CHECK_EQ(input_depth % 16, 0); + CHECK_EQ(output_depth % 4, 0); + float* shuffled_data_ptr = shuffled_data.data(); + for (int block_o = 0; block_o < output_depth; block_o += 4) { + for (int block_i = 0; block_i < input_depth; block_i += 16) { + for (int o = 0; o < 4; o++) { + for (int i = 0; i < 16; i++) { + *shuffled_data_ptr++ = + data[(block_o + o) * input_depth + block_i + i]; + } + } + } + } + TfLiteTensor* t = interpreter_->tensor(weights_); + auto quantized_data = + Quantize(shuffled_data, t->params.scale, t->params.zero_point); + for (uint8_t& q : quantized_data) { + q ^= 0x80; + } + PopulateTensor(weights_, 0, quantized_data.data(), + quantized_data.data() + quantized_data.size()); + } + void SetInput(const std::vector& data) { QuantizeAndPopulate(input_, data); } - std::vector GetOutput() { return ExtractVector(output_); } + template + std::vector GetOutput() { + return ExtractVector(output_); + } + + template std::vector GetDequantizedOutput() { - return Dequantize(ExtractVector(output_), - GetScale(output_), GetZeroPoint(output_)); + return Dequantize(ExtractVector(output_), GetScale(output_), + GetZeroPoint(output_)); } }; @@ -256,12 +291,12 @@ class HybridFullyConnectedOpModel : public SingleOpModel { ops::builtin::Register_FULLY_CONNECTED_PIE()); BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)}); } - void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } - void SetWeights(std::initializer_list data) { + void SetBias(const std::vector& f) { PopulateTensor(bias_, f); } + void SetWeights(const std::vector& data) { SymmetricQuantizeAndPopulate(weights_, data); } - void SetInput(std::initializer_list f) { PopulateTensor(input_, f); } + void SetInput(const std::vector& f) { PopulateTensor(input_, f); } std::vector GetOutput() { return ExtractVector(output_); } int input_size() { return input_size_; } @@ -340,6 +375,24 @@ TEST_P(FloatFullyConnectedOpTest, SimpleTest) { EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60)); } +TEST_P(FloatFullyConnectedOpTest, SimpleTest2) { + FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/1, /*batches=*/2, + /*input=*/{TensorType_FLOAT32, {2, 2}}); + m.SetWeights({ + 2, 4, // u = 0 + }); + m.SetBias({1}); + + m.SetInput({ + 1, 2, // b = 0 + 2, 1, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(11, 9)); +} + TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantized) { QuantizedFullyConnectedOpModel m( GetRegistration(), /*units=*/3, /*batches*/ 2, @@ -350,7 +403,7 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantized) { m.SetWeights({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 }); m.SetBias({1, 2, 3}); @@ -361,11 +414,136 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantized) { m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ - 24, 25, 26, // - 58, 59, 60, // - }))); - EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187)); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({ + 24, 25, 26, // + 58, 59, 60, // + }))); + EXPECT_THAT(m.GetOutput(), + ElementsAre(151, 152, 153, 185, 186, 187)); +} + +void SimpleTestQuantizedInt16OutputCase( + TfLiteRegistration* registration, int input_depth, int output_depth, + int batches, FullyConnectedOptionsWeightsFormat weights_format) { + const uint8_t kWeightsZeroPoint = 128; + const float kWeightsScale = 1.f / 128.f; + const uint8_t kInputZeroPoint = 128; + const float kInputScale = 1.f / 128.f; + const float kInputMin = (0 - kInputZeroPoint) * kInputScale; + const float kInputMax = (255 - kInputZeroPoint) * kInputScale; + // Output ranges in [-8..8] encoded as int16 + const float kOutputScale = 8.f / 32768.f; + const float kOutputMin = -32768 * kOutputScale; + const float kOutputMax = 32767 * kOutputScale; + + QuantizedFullyConnectedOpModel m( + registration, output_depth, batches, + /*input=*/ + {TensorType_UINT8, {batches, input_depth}, kInputMin, kInputMax}, + /*output=*/{TensorType_INT16, {}, kOutputMin, kOutputMax}, + /*activation_func=*/ActivationFunctionType_NONE, weights_format); + + std::mt19937 random_engine; + std::uniform_int_distribution weights_dist; + + std::vector weights_data(input_depth * output_depth); + for (auto& w : weights_data) { + uint8_t q = weights_dist(random_engine); + w = (q - kWeightsZeroPoint) * kWeightsScale; + } + + // Based on weights_format, enforce any shape requirement for that format/path + // and set the (possibly shuffled) weights. + switch (weights_format) { + case FullyConnectedOptionsWeightsFormat_DEFAULT: + m.SetWeights(weights_data); + break; + case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8: + // The shuffled path currently supports only a restrictive subset of + // shapes, described by the following assertions: + CHECK_EQ(input_depth % 16, 0); + CHECK_EQ(output_depth % 4, 0); + CHECK(batches == 1 || batches == 4); + m.ShuffleAndSetWeights(weights_data, input_depth, output_depth); + break; + default: + LOG(FATAL) << "Unhandled weights format"; + } + + std::uniform_int_distribution input_dist; + std::vector input_data(input_depth * batches); + for (auto& i : input_data) { + uint8_t q = input_dist(random_engine); + i = (q - kInputZeroPoint) * kInputScale; + } + + std::vector bias_data(output_depth); + // As the output ranges in [-8, 8], it's reasonable to have bias values + // in [-1, 1], this won't result in too much saturation. + std::uniform_real_distribution bias_dist(-1.f, 1.f); + for (auto& b : bias_data) { + b = bias_dist(random_engine); + } + + m.SetBias(bias_data); + m.SetInput(input_data); + + m.Invoke(); + + std::vector expected_output_data(output_depth * batches); + for (int b = 0; b < batches; b++) { + for (int o = 0; o < output_depth; o++) { + float accum = bias_data[o]; + for (int i = 0; i < input_depth; i++) { + accum += + input_data[b * input_depth + i] * weights_data[o * input_depth + i]; + } + accum = std::min(accum, kOutputMax); + accum = std::max(accum, kOutputMin); + expected_output_data[b * output_depth + o] = accum; + } + } + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear(expected_output_data, 3e-4f))); +} + +TEST_P(QuantizedFullyConnectedOpTest, + SimpleTestQuantizedInt16OutputDefaultWeights) { + for (int input_depth : {1, 3, 10, 100}) { + for (int output_depth : {1, 3, 10, 100}) { + for (int batch : {1, 3, 10, 100}) { + SimpleTestQuantizedInt16OutputCase( + GetRegistration(), input_depth, output_depth, batch, + FullyConnectedOptionsWeightsFormat_DEFAULT); + } + } + } +} + +TEST_P(QuantizedFullyConnectedOpTest, + SimpleTestQuantizedInt16OutputShuffled4x16Int8Weights) { + // The shuffled weights block shape is 4x16. The shape of the weights matrix + // is: rows = output_depth, cols = input_depth. It must be a multiple of 4x16. + // This means that output_depth must be a multiple of 4, and input_deth must + // be a multiple of 16. + for (int input_depth_numblocks : {1, 3}) { + for (int output_depth_numblocks : {1, 3}) { + int input_depth = 16 * input_depth_numblocks; + int output_depth = 4 * output_depth_numblocks; + // The fast shuffled path is currently supporting only batch sizes of 1 + // and 4. The idea is that the whole point of that path is to go as fast + // as possible for small batch size, which requires fully specializing + // it for each batch size, and for larger batch sizes the generic + // gemmlowp-based implementation is fast enough. + for (int batch : {1, 4}) { + SimpleTestQuantizedInt16OutputCase( + GetRegistration(), input_depth, output_depth, batch, + FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8); + } + } + } } TEST(HybridFullyConnectedOpTest, SimpleTestQuantized) { @@ -396,11 +574,11 @@ TEST(HybridFullyConnectedOpTest, SimpleTestQuantized) { /*max_abs_error=*/1.3f))); } -TEST(FloatFullyConnectedOpTest, SimpleTest4DInput) { +TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput) { // Note that it is not required that the first dimension be the number of // batches. All we care is that the input can be evenly distributed in // batches. In this case, we need the input to have multiples of '2'. - FloatFullyConnectedOpModel m(ops::builtin::Register_FULLY_CONNECTED_PIE(), + FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/3, /*batches=*/2, /*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}}); m.SetWeights({ @@ -444,11 +622,13 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTest4dInputQuantized) { m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ - 24, 25, 26, // - 58, 59, 60, // - }))); - EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187)); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({ + 24, 25, 26, // + 58, 59, 60, // + }))); + EXPECT_THAT(m.GetOutput(), + ElementsAre(151, 152, 153, 185, 186, 187)); } INSTANTIATE_TEST_CASE_P( diff --git a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc index e786f785abe3aa66a9fb243dd4f332ca91676863..d2f1103e14b40b81c59c8053bcdbee30c85e5c78 100644 --- a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc @@ -32,19 +32,21 @@ namespace tflite { namespace { void RunLogSoftmaxFloatReference(const uint8* input_data, - const Dims<4>& dims_common, int32 input_offset, - const double input_scale, int stride, - float beta, uint8* reference_output_data) { - const int ref_buffer_size = RequiredBufferSizeForDims(dims_common); + const RuntimeShape& shape_common, + int32 input_offset, const double input_scale, + int stride, float beta, + uint8* reference_output_data) { + const int ref_buffer_size = shape_common.FlatSize(); std::vector reference_dequant_data(ref_buffer_size); std::vector reference_output_float_data(ref_buffer_size); // Reference data generated via Dequant of input into float, and then applying // float LogSoftmax. - reference_ops::Dequantize(input_data, dims_common, input_offset, input_scale, - reference_dequant_data.data(), dims_common); - optimized_ops::LogSoftmax(reference_dequant_data.data(), dims_common, - reference_output_float_data.data(), dims_common); + reference_ops::Dequantize( + input_data, ToRuntimeDims(shape_common), input_offset, input_scale, + reference_dequant_data.data(), ToRuntimeDims(shape_common)); + optimized_ops::LogSoftmax(reference_dequant_data.data(), shape_common, + reference_output_float_data.data(), shape_common); // Work with quantized scaling for LogSoftmax, under which 255 represents 0, // and -16 gets nudged up to 0. for (int i = 0; i < ref_buffer_size; i++) { @@ -55,9 +57,9 @@ void RunLogSoftmaxFloatReference(const uint8* input_data, } void CheckOutputData(const uint8* test_output, const uint8* reference_output, - const Dims<4>& dims_common, const string& check_label, - bool be_exacting) { - const int buffer_size = RequiredBufferSizeForDims(dims_common); + const RuntimeShape& shape_common, + const string& check_label, bool be_exacting) { + const int buffer_size = shape_common.FlatSize(); // While calculating some metrics in floating point, we work with quantized // scaling. std::vector diff(buffer_size); @@ -99,15 +101,15 @@ void CheckOutputData(const uint8* test_output, const uint8* reference_output, // Runs the LogSoftmax and compares against the float reference implementation // and the quantized reference implementation. -void RunOneLogSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common, - int32 input_offset, const double input_scale, - int stride, float beta) { - const int buffer_size = RequiredBufferSizeForDims(dims_common); +void RunOneLogSoftmaxTest(const uint8* input_data, + const RuntimeShape& shape_common, int32 input_offset, + const double input_scale, int stride, float beta) { + const int buffer_size = shape_common.FlatSize(); std::vector optimized_logsoftmax_output(buffer_size); std::vector reference_float_logsoftmax_output(buffer_size); std::vector reference_quant_logsoftmax_output(buffer_size); - RunLogSoftmaxFloatReference(input_data, dims_common, input_offset, + RunLogSoftmaxFloatReference(input_data, shape_common, input_offset, input_scale, stride, beta, reference_float_logsoftmax_output.data()); @@ -126,23 +128,23 @@ void RunOneLogSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common, const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits, input_beta_left_shift); - optimized_ops::LogSoftmax(input_data, dims_common, input_beta_multiplier, + optimized_ops::LogSoftmax(input_data, shape_common, input_beta_multiplier, input_beta_left_shift, reverse_scaling_divisor, reverse_scaling_right_shift, diff_min, - optimized_logsoftmax_output.data(), dims_common); + optimized_logsoftmax_output.data(), shape_common); reference_ops::LogSoftmax( - input_data, dims_common, input_beta_multiplier, input_beta_left_shift, + input_data, shape_common, input_beta_multiplier, input_beta_left_shift, reverse_scaling_divisor, reverse_scaling_right_shift, diff_min, - reference_quant_logsoftmax_output.data(), dims_common); + reference_quant_logsoftmax_output.data(), shape_common); CheckOutputData(optimized_logsoftmax_output.data(), - reference_float_logsoftmax_output.data(), dims_common, + reference_float_logsoftmax_output.data(), shape_common, "Optimized vs float reference", false); CheckOutputData(optimized_logsoftmax_output.data(), - reference_quant_logsoftmax_output.data(), dims_common, + reference_quant_logsoftmax_output.data(), shape_common, "Optimized vs quant reference", true); CheckOutputData(reference_quant_logsoftmax_output.data(), - reference_float_logsoftmax_output.data(), dims_common, + reference_float_logsoftmax_output.data(), shape_common, "Quant reference vs float reference", false); } @@ -165,13 +167,13 @@ bool TryOneUniformLogSoftmax() { const int32 input_offset = UniformRandomInt(-256, 0); static constexpr float beta = 1.0f; - Dims<4> dims_common = - MakeDimsForInference(input_depth, input_width, input_height, batch); - const int buffer_size = RequiredBufferSizeForDims(dims_common); + auto shape_common = + RuntimeShape({batch, input_height, input_width, input_depth}); + const int buffer_size = shape_common.FlatSize(); std::vector input_data(buffer_size); FillRandom(&input_data); - RunOneLogSoftmaxTest(input_data.data(), dims_common, input_offset, + RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale, stride, beta); return true; } @@ -203,14 +205,14 @@ bool TryOneSkyscraperLogSoftmax(bool small_depth) { const int middle_min = UniformRandomInt(0, 255); const int sides_max = UniformRandomInt(0, middle_min); - Dims<4> dims_common = - MakeDimsForInference(input_depth, input_width, input_height, batch); - const int buffer_size = RequiredBufferSizeForDims(dims_common); + auto shape_common = + RuntimeShape({batch, input_height, input_width, input_depth}); + const int buffer_size = shape_common.FlatSize(); std::vector input_data(buffer_size); FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min, sides_max); - RunOneLogSoftmaxTest(input_data.data(), dims_common, input_offset, + RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale, stride, beta); return true; } diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h index c0dda4acf1a59de22d07905f2a2e2bbe422e8d21..7816752132761d9523ffc1f45b3740c0817ed402 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h @@ -26,6 +26,10 @@ limitations under the License. namespace tflite { namespace optimized_ops { +// Unoptimized reference ops: +using reference_ops::Relu1; +using reference_ops::Relu6; + inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) { return RuntimeShape( {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]}); @@ -34,15 +38,285 @@ inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) { template void L2Normalization(const float* input_data, const Dims<4>& input_dims, float* output_data, const Dims<4>& output_dims) { - return L2Normalization(input_data, DimsToShape(input_dims), output_data, - DimsToShape(output_dims)); + L2Normalization(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); } inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, int32 input_zero_point, uint8* output_data, const Dims<4>& output_dims) { - return L2Normalization(input_data, DimsToShape(input_dims), input_zero_point, - output_data, DimsToShape(output_dims)); + L2Normalization(input_data, DimsToShape(input_dims), input_zero_point, + output_data, DimsToShape(output_dims)); +} + +inline void Relu(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Relu(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, float* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, filter_width, filter_height, + output_activation_min, output_activation_max, output_data, + DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, filter_width, filter_height, + output_activation_min, output_activation_max, output_data, + DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + L2Pool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, filter_width, filter_height, + output_activation_min, output_activation_max, output_data, + DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void Softmax(const float* input_data, const Dims<4>& input_dims, + float beta, float* output_data, + const Dims<4>& output_dims) { + Softmax(input_data, DimsToShape(input_dims), beta, output_data, + DimsToShape(output_dims)); +} + +inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_beta_multiplier, int32 input_beta_left_shift, + int diff_min, uint8* output_data, + const Dims<4>& output_dims) { + Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier, + input_beta_left_shift, diff_min, output_data, + DimsToShape(output_dims)); +} + +inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + LogSoftmax(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_multiplier, int32 input_left_shift, + int32 reverse_scaling_divisor, + int32 reverse_scaling_right_shift, int diff_min, + uint8* output_data, const Dims<4>& output_dims) { + LogSoftmax(input_data, DimsToShape(input_dims), input_multiplier, + input_left_shift, reverse_scaling_divisor, + reverse_scaling_right_shift, diff_min, output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), input_zero_point, + input_range_radius, input_multiplier, input_left_shift, output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const int16* input_data, const Dims<4>& input_dims, + int16* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), input_zero_point, + input_range_radius, input_multiplier, input_left_shift, output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const int16* input_data, const Dims<4>& input_dims, + int input_left_shift, int16* output_data, + const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data, + DimsToShape(output_dims)); } } // namespace optimized_ops diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index d0008cc4fb62c2105d6817a6e44cefa974a31dbc..8597707b24325588b1b4dc4f4ac68ccfa9cecd36 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -40,16 +40,30 @@ namespace tflite { namespace optimized_ops { // Unoptimized reference ops: +using reference_ops::ArgMax; using reference_ops::BroadcastGreater; using reference_ops::BroadcastGreaterEqual; using reference_ops::BroadcastLess; using reference_ops::BroadcastLessEqual; +using reference_ops::Concatenation; +using reference_ops::DepthConcatenation; +using reference_ops::Dequantize; +using reference_ops::Div; +using reference_ops::FakeQuant; +using reference_ops::Gather; using reference_ops::Greater; using reference_ops::GreaterEqual; using reference_ops::Less; using reference_ops::LessEqual; +using reference_ops::Mean; using reference_ops::RankOneSelect; +using reference_ops::Relu1; +using reference_ops::Relu6; +using reference_ops::ReluX; using reference_ops::Select; +using reference_ops::SpaceToBatchND; +using reference_ops::StridedSlice; +using reference_ops::Transpose; // TODO(b/80247582) Remove this constant. // This will be phased out as the shifts are revised with more thought. Use of a @@ -72,6 +86,12 @@ using VectorMap = typename std::conditional< Eigen::Dynamic, 1>>, Eigen::Map>>::type; +template +VectorMap MapAsVector(Scalar* data, const RuntimeShape& shape) { + const int size = shape.FlatSize(); + return VectorMap(data, size, 1); +} + template VectorMap MapAsVector(Scalar* data, const Dims& dims) { const int size = FlatSize(dims); @@ -88,6 +108,23 @@ using MatrixMap = typename std::conditional< Eigen::Dynamic, Eigen::Dynamic>>, Eigen::Map>>::type; +template +MatrixMap MapAsMatrixWithLastDimAsRows(Scalar* data, + const RuntimeShape& shape) { + const int dims_count = shape.DimensionsCount(); + const int rows = shape.Dims(dims_count - 1); + const int cols = FlatSizeSkipDim(shape, dims_count - 1); + return MatrixMap(data, rows, cols); +} + +template +MatrixMap MapAsMatrixWithFirstDimAsCols(Scalar* data, + const RuntimeShape& shape) { + const int cols = shape.Dims(0); + const int rows = FlatSizeSkipDim(shape, 0); + return MatrixMap(data, rows, cols); +} + template MatrixMap MapAsMatrixWithFirstDimAsRows(Scalar* data, const Dims& dims) { @@ -134,16 +171,9 @@ template MatrixMap MapAsMatrixWithGivenNumberOfRows(Scalar* data, const Dims& dims, int rows) { - int cols = 1; - bool matched_rows = false; - for (int d = 0; d < N; d++) { - cols *= dims.sizes[d]; - if (cols == rows) { - matched_rows = true; - cols = 1; - } - } - TFLITE_DCHECK(matched_rows); + const int flatsize = FlatSize(dims); + TFLITE_DCHECK((flatsize % rows) == 0); + const int cols = flatsize / rows; return MatrixMap(data, rows, cols); } @@ -1256,11 +1286,11 @@ void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, } // Internal function doing the actual arithmetic work for -// ExperimentalShuffledFullyConnected. +// ShuffledFullyConnected. // May be called either directly by it (single-threaded case) or may be used // as the 'task' for worker threads to run (multi-threaded case, see -// ExperimentalShuffledFullyConnectedWorkerTask below). -inline void ExperimentalShuffledFullyConnectedWorkerImpl( +// ShuffledFullyConnectedWorkerTask below). +inline void ShuffledFullyConnectedWorkerImpl( const uint8* shuffled_input_workspace_data, const int8* shuffled_weights_data, int batches, int output_depth, int output_stride, int accum_depth, const int32* bias_data, @@ -1534,14 +1564,16 @@ inline void ExperimentalShuffledFullyConnectedWorkerImpl( #endif } -// Wraps ExperimentalShuffledFullyConnectedWorkerImpl into a Task class +// Wraps ShuffledFullyConnectedWorkerImpl into a Task class // to allow using gemmlowp's threadpool. -struct ExperimentalShuffledFullyConnectedWorkerTask : gemmlowp::Task { - ExperimentalShuffledFullyConnectedWorkerTask( - const uint8* input_data, const int8* shuffled_weights_data, int batches, - int output_depth, int output_stride, int accum_depth, - const int32* bias_data, int32 output_multiplier, int output_shift, - int16* output_data) +struct ShuffledFullyConnectedWorkerTask : gemmlowp::Task { + ShuffledFullyConnectedWorkerTask(const uint8* input_data, + const int8* shuffled_weights_data, + int batches, int output_depth, + int output_stride, int accum_depth, + const int32* bias_data, + int32 output_multiplier, int output_shift, + int16* output_data) : input_data_(input_data), shuffled_weights_data_(shuffled_weights_data), batches_(batches), @@ -1554,7 +1586,7 @@ struct ExperimentalShuffledFullyConnectedWorkerTask : gemmlowp::Task { output_data_(output_data) {} void Run() override { - ExperimentalShuffledFullyConnectedWorkerImpl( + ShuffledFullyConnectedWorkerImpl( input_data_, shuffled_weights_data_, batches_, output_depth_, output_stride_, accum_depth_, bias_data_, output_multiplier_, output_shift_, output_data_); @@ -1572,15 +1604,14 @@ struct ExperimentalShuffledFullyConnectedWorkerTask : gemmlowp::Task { int16* output_data_; }; -inline void ExperimentalShuffledFullyConnected( +inline void ShuffledFullyConnected( const uint8* input_data, const Dims<4>& input_dims, const uint8* shuffled_weights_data, const Dims<4>& weights_dims, const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier, int output_shift, int32 output_activation_min, int32 output_activation_max, int16* output_data, const Dims<4>& output_dims, uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) { - gemmlowp::ScopedProfilingLabel label( - "ExperimentalShuffledFullyConnected/8bit"); + gemmlowp::ScopedProfilingLabel label("ShuffledFullyConnected/8bit"); (void)gemm_context; // only used in optimized code. TFLITE_DCHECK_EQ(output_activation_min, -32768); TFLITE_DCHECK_EQ(output_activation_max, 32767); @@ -1664,7 +1695,7 @@ inline void ExperimentalShuffledFullyConnected( if (thread_count == 1) { // Single-thread case: do the computation on the current thread, don't // use a threadpool - ExperimentalShuffledFullyConnectedWorkerImpl( + ShuffledFullyConnectedWorkerImpl( shuffled_input_workspace_data, int8_shuffled_weights_data, batches, output_depth, output_depth, accum_depth, bias_data, output_multiplier, output_shift, output_data); @@ -1679,7 +1710,7 @@ inline void ExperimentalShuffledFullyConnected( int row_start = 0; for (int i = 0; i < thread_count; i++) { int row_end = std::min(output_depth, row_start + kRowsPerWorker); - tasks[i] = new ExperimentalShuffledFullyConnectedWorkerTask( + tasks[i] = new ShuffledFullyConnectedWorkerTask( shuffled_input_workspace_data, int8_shuffled_weights_data + row_start * accum_depth, batches, row_end - row_start, output_depth, accum_depth, bias_data + row_start, @@ -2330,41 +2361,15 @@ void GlobalBatchNormalization(const float* input_data, } } -inline void Relu(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +inline void Relu(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Relu (not fused)"); - const auto input = MapAsVector(input_data, input_dims); - auto output = MapAsVector(output_data, output_dims); + const auto input = MapAsVector(input_data, input_shape); + auto output = MapAsVector(output_data, output_shape); output = input.cwiseMax(0.0f); } -inline void Relu1(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)"); - const int flat_size = MatchingFlatSize(input_dims, output_dims); - for (int i = 0; i < flat_size; ++i) { - const float val = input_data[i]; - const float upper = 1; - const float lower = -1; - const float clamped = val > upper ? upper : val < lower ? lower : val; - output_data[i] = clamped; - } -} - -inline void Relu6(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)"); - const int flat_size = MatchingFlatSize(input_dims, output_dims); - for (int i = 0; i < flat_size; ++i) { - const float val = input_data[i]; - const float upper = 6; - const float lower = 0; - const float clamped = val > upper ? upper : val < lower ? lower : val; - output_data[i] = clamped; - } -} - template void L2Normalization(const float* input_data, const RuntimeShape& input_shape, float* output_data, const RuntimeShape& output_shape) { @@ -2671,25 +2676,13 @@ inline void Add(int left_shift, const uint8* input1_data, output_activation_max, output_data); } -template inline void Add(const int16* input1_data, const Dims<4>& input1_dims, int input1_shift, const int16* input2_data, const Dims<4>& input2_dims, int input2_shift, int16 output_activation_min, int16 output_activation_max, int16* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("Add/Int16"); - // This is a copy of the reference implementation. We do not currently have a - // properly optimized version. - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, -32768); - TFLITE_DCHECK_EQ(output_activation_max, 32767); - } const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); @@ -2715,6 +2708,42 @@ inline void Add(const int16* input1_data, const Dims<4>& input1_dims, } } +inline void Add(const int32* input1_data, const Dims<4>& input1_dims, + const int32* input2_data, const Dims<4>& input2_dims, + int32 output_activation_min, int32 output_activation_max, + int32* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Add/int32"); + + const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax( + input1_data[i] + input2_data[i], output_activation_min, + output_activation_max); + } +} + +template +inline void Add(const int16* input1_data, const Dims<4>& input1_dims, + int input1_shift, const int16* input2_data, + const Dims<4>& input2_dims, int input2_shift, + int16 output_activation_min, int16 output_activation_max, + int16* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, -32768); + TFLITE_DCHECK_EQ(output_activation_max, 32767); + } + + Add(input1_data, input1_dims, input1_shift, input2_data, input2_dims, + input2_shift, output_activation_min, output_activation_max, output_data, + output_dims); +} + template void Add(const int32* input1_data, const Dims<4>& input1_dims, const int32* input2_data, const Dims<4>& input2_dims, @@ -3215,19 +3244,6 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, output_data, output_dims); } -// TODO(aselle): This is not actually optimized yet. -inline void Div(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); - for (int i = 0; i < flat_size; i++) { - output_data[i] = ActivationFunctionWithMinMax( - input1_data[i] / input2_data[i], output_activation_min, - output_activation_max); - } -} - // TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then @@ -3393,105 +3409,6 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data, } } -template -void Concatenation(int concat_dim, const Scalar* const* input_data, - const Dims<4>* const* input_dims, int inputs_count, - Scalar* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Concatenation"); - int concat_size = 0; - for (int i = 0; i < inputs_count; i++) { - for (int j = 0; j < 4; j++) { - if (j != concat_dim) { - MatchingArraySize(*input_dims[i], j, output_dims, j); - } - } - concat_size += ArraySize(*input_dims[i], concat_dim); - } - TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim)); - TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); - // for now we dont have a model with a Concatenation - // with fused activation function. - TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); - int outer_size = 1; - for (int i = concat_dim + 1; i < 4; i++) { - outer_size *= output_dims.sizes[i]; - } - Scalar* output_ptr = output_data; - for (int k = 0; k < outer_size; k++) { - for (int i = 0; i < inputs_count; ++i) { - const int copy_size = - input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim]; - memcpy(output_ptr, input_data[i] + k * copy_size, - copy_size * sizeof(Scalar)); - output_ptr += copy_size; - } - } -} - -// TODO(prabhumk): This is the same as the reference implementation. -// TODO(prabhumk): The quantized implementation of concatentation isn't fully -// quantized as it takes scale as a floating point value. This should be fixed -// when optimizng this routine further. -inline void Concatenation(int concat_dim, const uint8* const* input_data, - const Dims<4>* const* input_dims, - const int32* input_zeropoint, - const float* input_scale, int inputs_count, - uint8* output_data, const Dims<4>& output_dims, - const int32 output_zeropoint, - const float output_scale) { - // The arguments input_zeropoint and input_scale are expected to be an array - // that have the quantization parameters for all the inputs to the concat - // operator. - gemmlowp::ScopedProfilingLabel label("Concatenation"); - TFLITE_DCHECK_GT(inputs_count, 1); - int concat_size = 0; - for (int i = 0; i < inputs_count; i++) { - for (int j = 0; j < 4; j++) { - if (j != concat_dim) { - MatchingArraySize(*input_dims[i], j, output_dims, j); - } - } - concat_size += ArraySize(*input_dims[i], concat_dim); - } - TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim)); - int outer_size = 1; - for (int i = concat_dim + 1; i < 4; i++) { - outer_size *= output_dims.sizes[i]; - } - const float inverse_output_scale = 1.f / output_scale; - uint8* output_ptr = output_data; - for (int k = 0; k < outer_size; k++) { - for (int i = 0; i < inputs_count; ++i) { - const int copy_size = - input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim]; - const uint8* input_ptr = input_data[i] + k * copy_size; - if (input_zeropoint[i] == output_zeropoint && - input_scale[i] == output_scale) { - memcpy(output_ptr, input_ptr, copy_size); - } else { - const float scale = input_scale[i] * inverse_output_scale; - const float bias = -input_zeropoint[i] * scale; - for (int j = 0; j < copy_size; ++j) { - const int32_t value = - static_cast(round(input_ptr[j] * scale + bias)) + - output_zeropoint; - output_ptr[j] = - static_cast(std::max(std::min(255, value), 0)); - } - } - output_ptr += copy_size; - } - } -} - -template -void DepthConcatenation(const Scalar* const* input_data, - const Dims<4>* const* input_dims, int inputs_count, - Scalar* output_data, const Dims<4>& output_dims) { - Concatenation(0, input_data, input_dims, inputs_count, - output_data, output_dims); -} - inline void LstmCell(const float* input_data, const Dims<4>& input_dims, const float* prev_activ_data, const Dims<4>& prev_activ_dims, const float* weights_data, @@ -3854,23 +3771,25 @@ inline int NodeOffset(int b, int h, int w, int height, int width) { return (b * height + h) * width + w; } -inline void AveragePool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int kwidth, int kheight, - float output_activation_min, +inline void AveragePool(const float* input_data, + const RuntimeShape& input_shape, int stride_width, + int stride_height, int pad_width, int pad_height, + int kwidth, int kheight, float output_activation_min, float output_activation_max, float* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("AveragePool"); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); // TODO(benoitjacob) make this a proper reference impl without Eigen! - const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); - auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape); + auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape); // TODO(benoitjacob) get rid of the dynamic memory allocation here! Eigen::VectorXf out_count(out_mat.cols()); out_count.setZero(); @@ -3908,9 +3827,9 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims, for (int y = 0; y < output_height; ++y) { for (int x = 0; x < output_width; ++x) { for (int c = 0; c < depth; ++c) { - output_data[Offset(output_dims, c, x, y, b)] = + output_data[Offset(output_shape, b, y, x, c)] = ActivationFunctionWithMinMax( - output_data[Offset(output_dims, c, x, y, b)], + output_data[Offset(output_shape, b, y, x, c)], output_activation_min, output_activation_max); } } @@ -3918,44 +3837,23 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void AveragePool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int kwidth, int kheight, float* output_data, - const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - - AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, kwidth, kheight, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, - int filter_height, float* output_data, - const Dims<4>& output_dims) { - AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - -inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, +inline void AveragePool(const uint8* input_data, + const RuntimeShape& input_shape, int stride_width, + int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, int32 output_activation_min, int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("AveragePool/8bit"); TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { @@ -3975,11 +3873,12 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, uint16 acc[kAccBufferMaxSize]; memset(acc, 0, depth * sizeof(acc[0])); const uint8* input_ptr = - input_data + input_dims.strides[1] * in_x_origin + - input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch; + input_data + + depth * (in_x_origin + + input_width * (in_y_origin + input_height * batch)); for (int fy = filter_y_start; fy < filter_y_end; fy++) { - const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] + - filter_x_start * input_dims.strides[1]; + const uint8* input_row_ptr = + input_ptr + depth * (fy * input_width + filter_x_start); for (int fx = filter_x_start; fx < filter_x_end; fx++) { int channel = 0; #ifdef USE_NEON @@ -4010,7 +3909,7 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, } } uint8* output_ptr = - output_data + Offset(output_dims, 0, out_x, out_y, batch); + output_data + Offset(output_shape, batch, out_y, out_x, 0); int channel = 0; #ifdef USE_NEON #define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \ @@ -4051,54 +3950,23 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void AveragePool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, - int filter_height, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -inline void MaxPool(const float* input_data, const Dims<4>& input_dims, +inline void MaxPool(const float* input_data, const RuntimeShape& input_shape, int stride_width, int stride_height, int pad_width, int pad_height, int kwidth, int kheight, float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("MaxPool"); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - - const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); - auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + + const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape); + auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape); // Prefill the output to minimum representable float value out_mat.setConstant(std::numeric_limits::lowest()); for (int b = 0; b < batches; ++b) { @@ -4131,9 +3999,9 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims, for (int y = 0; y < output_height; ++y) { for (int x = 0; x < output_width; ++x) { for (int c = 0; c < depth; ++c) { - output_data[Offset(output_dims, c, x, y, b)] = + output_data[Offset(output_shape, b, y, x, c)] = ActivationFunctionWithMinMax( - output_data[Offset(output_dims, c, x, y, b)], + output_data[Offset(output_shape, b, y, x, c)], output_activation_min, output_activation_max); } } @@ -4141,41 +4009,21 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void MaxPool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int kwidth, int kheight, float* output_data, - const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, kwidth, kheight, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - float* output_data, const Dims<4>& output_dims) { - MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - -inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, +inline void MaxPool(const uint8* input_data, const RuntimeShape& input_shape, int stride_width, int stride_height, int pad_width, int pad_height, int filter_width, int filter_height, int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("MaxPool/8bit"); TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { @@ -4193,11 +4041,12 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, uint8 acc[kAccBufferMaxSize]; memset(acc, 0, depth * sizeof(acc[0])); const uint8* input_ptr = - input_data + input_dims.strides[1] * in_x_origin + - input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch; + input_data + + depth * (in_x_origin + + input_width * (in_y_origin + input_height * batch)); for (int fy = filter_y_start; fy < filter_y_end; fy++) { - const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] + - filter_x_start * input_dims.strides[1]; + const uint8* input_row_ptr = + input_ptr + depth * (fy * input_width + filter_x_start); for (int fx = filter_x_start; fx < filter_x_end; fx++) { int channel = 0; #ifdef USE_NEON @@ -4223,7 +4072,7 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, } } uint8* output_ptr = - output_data + Offset(output_dims, 0, out_x, out_y, batch); + output_data + Offset(output_shape, batch, out_y, out_x, 0); int channel = 0; #ifdef USE_NEON for (; channel <= depth - 16; channel += 16) { @@ -4250,53 +4099,23 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void MaxPool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -inline void L2Pool(const float* input_data, const Dims<4>& input_dims, +inline void L2Pool(const float* input_data, const RuntimeShape& input_shape, int stride_width, int stride_height, int pad_width, int pad_height, int filter_width, int filter_height, float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("L2Pool"); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); // Actually carry out L2 Pool. Code is written in forward mode: we go through // the input values once, and write to all the pooled regions that it maps to. - const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); - auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape); + auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape); Eigen::VectorXf in_square(in_mat.rows()); Eigen::VectorXf out_count(out_mat.cols()); out_count.setZero(); @@ -4338,28 +4157,6 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims, (out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt(); } -// legacy, for compatibility with old checked-in code -template -void L2Pool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, float* output_data, - const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - float* output_data, const Dims<4>& output_dims) { - L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - inline void LocalResponseNormalization(const float* input_data, const Dims<4>& input_dims, int range, float bias, float alpha, float beta, @@ -4405,14 +4202,14 @@ inline void LocalResponseNormalization(const float* input_data, } } -inline void Softmax(const float* input_data, const Dims<4>& input_dims, +inline void Softmax(const float* input_data, const RuntimeShape& input_shape, float beta, float* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Softmax"); - MatchingFlatSize(input_dims, output_dims); + MatchingFlatSize(input_shape, output_shape); - const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); - auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape); + auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape); // Compute the exponential first, removing the max coefficient for numerical // stability. out_mat = (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * beta; @@ -4424,10 +4221,10 @@ inline void Softmax(const float* input_data, const Dims<4>& input_dims, out_mat.array().rowwise() *= scale; } -inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, +inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape, int32 input_beta_multiplier, int32 input_beta_left_shift, int diff_min, uint8* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { // The representation chosen for the input to the exp() function is Q5.26. // We need to leave extra space since values that we skip might be as large as // -32 before multiplying by input_beta_multiplier, and therefore as large as @@ -4441,8 +4238,11 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, using FixedPoint0 = gemmlowp::FixedPoint; gemmlowp::ScopedProfilingLabel label("Softmax/8bit"); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int b = 0; b < outer_size; ++b) { const uint8* input_data_ptr = input_data + b * depth; @@ -4632,11 +4432,14 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, // TODO(myenik): This is the same as the reference implementation, not actually // optimized yet. -inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("LogSoftmax"); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { const float* block_input_data = input_data + i * depth; @@ -4777,11 +4580,11 @@ log_x_for_x_greater_than_or_equal_to_1( } // Currently just a copy of the reference code. -inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, +inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape, int32 input_multiplier, int32 input_left_shift, int32 reverse_scaling_divisor, int32 reverse_scaling_right_shift, int diff_min, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("LogSoftmax/Uint8"); // The representation chosen for the input to the exp() function is Q5.26. // We need to leave extra space since values that we skip might be as large as @@ -4796,8 +4599,11 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, using FixedPointAccum = gemmlowp::FixedPoint; using FixedPoint0 = gemmlowp::FixedPoint; - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { const uint8* block_input_data = input_data + i * depth; @@ -4861,21 +4667,21 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Logistic(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +inline void Logistic(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Logistic"); - auto input_map = MapAsVector(input_data, input_dims); - auto output_map = MapAsVector(output_data, output_dims); + auto input_map = MapAsVector(input_data, input_shape); + auto output_map = MapAsVector(output_data, output_shape); output_map.array() = input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op()); } -inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, +inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape, int32 input_zero_point, int32 input_range_radius, int32 input_multiplier, int input_left_shift, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Logistic/Uint8"); - const int size = MatchingFlatSize(input_dims, output_dims); + const int size = MatchingFlatSize(input_shape, output_shape); int c = 0; #ifdef USE_NEON @@ -5007,10 +4813,10 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Logistic(const int16* input_data, const Dims<4>& input_dims, - int16* output_data, const Dims<4>& output_dims) { +inline void Logistic(const int16* input_data, const RuntimeShape& input_shape, + int16* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Logistic/Int16"); - const int flat_size = MatchingFlatSize(output_dims, input_dims); + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { } @@ -5067,21 +4873,21 @@ inline void Logistic(const int16* input_data, const Dims<4>& input_dims, } } -inline void Tanh(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +inline void Tanh(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Tanh"); - auto input_map = MapAsVector(input_data, input_dims); - auto output_map = MapAsVector(output_data, output_dims); + auto input_map = MapAsVector(input_data, input_shape); + auto output_map = MapAsVector(output_data, output_shape); output_map.array() = input_map.array().tanh(); } -inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, +inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape, int32 input_zero_point, int32 input_range_radius, int32 input_multiplier, int input_left_shift, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { // Note that this is almost the exact same code as in Logistic(). gemmlowp::ScopedProfilingLabel label("Tanh"); - const int size = MatchingFlatSize(input_dims, output_dims); + const int size = MatchingFlatSize(input_shape, output_shape); int c = 0; int32_t output_zero_point = 128; @@ -5222,16 +5028,16 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Tanh(const int16* input_data, const Dims<4>& input_dims, +inline void Tanh(const int16* input_data, const RuntimeShape& input_shape, int input_left_shift, int16* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Tanh/Int16"); // Support for shifts is limited until we have a parameterized version of // SaturatingRoundingMultiplyByPOT(). TFLITE_DCHECK_GE(input_left_shift, 0); TFLITE_DCHECK_LE(input_left_shift, 1); - const int flat_size = MatchingFlatSize(output_dims, input_dims); + const int flat_size = MatchingFlatSize(input_shape, output_shape); int c = 0; const int16* input_data_ptr = input_data; @@ -5322,49 +5128,6 @@ inline void Tanh(const int16* input_data, const Dims<4>& input_dims, } } -inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, - int32 zero_point, double scale, float* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Dequantize"); - const int flat_size = MatchingFlatSize(output_dims, input_dims); - for (int i = 0; i < flat_size; ++i) { - int32 val = input_data[i]; - float result = static_cast(scale * (val - zero_point)); - output_data[i] = result; - } -} - -inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, - float rmin, float rmax, int num_bits, float* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("FakeQuant"); - - // 0 should always be a representable value. Let's assume that the initial - // min,max range contains 0. - TFLITE_DCHECK_LE(rmin, 0.0f); - TFLITE_DCHECK_GE(rmax, 0.0f); - TFLITE_DCHECK_LT(rmin, rmax); - - // Code matches tensorflow's FakeQuantWithMinMaxArgsFunctor. - int quant_min = 0; - int quant_max = (1 << num_bits) - 1; - float nudged_min, nudged_max, nudged_scale; - NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min, - &nudged_max, &nudged_scale); - const float inv_nudged_scale = 1.0f / nudged_scale; - - const int flat_size = MatchingFlatSize(output_dims, input_dims); - for (int i = 0; i < flat_size; ++i) { - const float src_val = input_data[i]; - const float clamped = std::min(nudged_max, std::max(nudged_min, src_val)); - const float clamped_shifted = clamped - nudged_min; - const float dst_val = - TfLiteRound(clamped_shifted * inv_nudged_scale) * nudged_scale + - nudged_min; - output_data[i] = dst_val; - } -} - template inline void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data, const Dims<4>& output_dims) { @@ -5382,26 +5145,6 @@ inline void Floor(const float* input_data, const Dims<4>& input_dims, output_map.array() = Eigen::floor(input_map.array()); } -template -inline void Gather(const T* input_data, const Dims<4>& input_dims, - int input_rank, const int32* coords_data, - const Dims<4>& coords_dims, T* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Gather"); - - TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]); - int stride = input_dims.strides[input_rank - 1]; - T* out = output_data; - - for (int i = 0; i < coords_dims.sizes[0]; i++) { - TFLITE_DCHECK_GE(coords_data[i], 0); - TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]); - const T* in = input_data + coords_data[i] * stride; - memcpy(out, in, sizeof(T) * stride); - out += stride; - } -} - #ifdef USE_NEON inline void ResizeBilinearKernel(const float* input_ptr, int32 depth, float scale, float* output_ptr) { @@ -5863,55 +5606,6 @@ inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, output_data, output_dims, /*align_corners=*/false); } -template -inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, - const int32* block_shape_data, - const Dims<4>& block_shape_dims, - const int32* paddings_data, - const Dims<4>& paddings_dims, T* output_data, - const Dims<4>& output_dims) { - // Unoptimized - Straight copy from reference ops. - gemmlowp::ScopedProfilingLabel label("SpaceToBatchND"); - - const int output_batch_size = ArraySize(output_dims, 3); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int input_batch_size = ArraySize(input_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int depth = ArraySize(input_dims, 0); - const int block_shape_height = block_shape_data[0]; - const int block_shape_width = block_shape_data[1]; - const int padding_top = paddings_data[0]; - const int padding_left = paddings_data[2]; - - for (int out_b = 0; out_b < output_batch_size; ++out_b) { - int input_batch = out_b % input_batch_size; - int shift_w = (out_b / input_batch_size) % block_shape_width; - int shift_h = (out_b / input_batch_size) / block_shape_width; - for (int out_h = 0; out_h < output_height; ++out_h) { - for (int out_w = 0; out_w < output_width; ++out_w) { - T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b); - if (out_h * block_shape_height + shift_h < padding_top || - out_h * block_shape_height + shift_h >= - padding_top + input_height || - out_w * block_shape_width + shift_w < padding_left || - out_w * block_shape_width + shift_w >= padding_left + input_width) { - memset(out, 0, depth * sizeof(T)); - } else { - const T* in = - input_data + - Offset(input_dims, 0, - (out_w * block_shape_width + shift_w) - padding_left, - (out_h * block_shape_height + shift_h) - padding_top, - input_batch); - memcpy(out, in, depth * sizeof(T)); - } - } - } - } -} - // Helper methods for BatchToSpaceND. // `spatial_index_dim` specifies post-crop offset index in this spatial // dimension, i.e. spatial offset introduced by flattening batch to spatial @@ -6114,54 +5808,6 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims, output_dims, 0); } -// UNOPTIMIZED COPY of StridedSlice from reference_ops.h. -template -inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, - int begin_mask, int end_mask, - const std::vector& start_indices, - const std::vector& stop_indices, - const std::vector& strides, T* output_data, - const Dims<4>& output_dims) { - TFLITE_DCHECK_EQ(start_indices.size(), 4); - TFLITE_DCHECK_EQ(stop_indices.size(), 4); - TFLITE_DCHECK_EQ(strides.size(), 4); - const int start_b = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 3); - const int stop_b = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 3); - const int start_h = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 2); - const int stop_h = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 2); - const int start_w = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 1); - const int stop_w = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 1); - const int start_d = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 0); - const int stop_d = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 0); - - T* out_ptr = output_data; - for (int in_b = start_b; - !strided_slice::LoopCondition(in_b, stop_b, strides[3]); - in_b += strides[3]) { - for (int in_h = start_h; - !strided_slice::LoopCondition(in_h, stop_h, strides[2]); - in_h += strides[2]) { - for (int in_w = start_w; - !strided_slice::LoopCondition(in_w, stop_w, strides[1]); - in_w += strides[1]) { - for (int in_d = start_d; - !strided_slice::LoopCondition(in_d, stop_d, strides[0]); - in_d += strides[0]) { - *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)]; - } - } - } - } -} - template inline void Slice(const T* input_data, const Dims<4>& input_dims, const std::vector& begin, const std::vector& size, @@ -6196,41 +5842,6 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims, } } -template -inline void Mean(const T* input_data, const Dims<4>& input_dims, - const std::vector& reduction_indices, T* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Mean"); - const int output_batch = ArraySize(output_dims, 3); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int output_depth = ArraySize(output_dims, 0); - - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - - // The current implementation only supports simultaneous reduction over - // width and height. - TFLITE_DCHECK_EQ(reduction_indices.size(), 2); - TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) || - (reduction_indices[0] == 2 && reduction_indices[1] == 1)); - TFLITE_DCHECK_EQ(output_height, 1); - TFLITE_DCHECK_EQ(output_width, 1); - - for (int out_b = 0; out_b < output_batch; ++out_b) { - for (int out_d = 0; out_d < output_depth; ++out_d) { - float value = 0; - for (int in_h = 0; in_h < input_height; ++in_h) { - for (int in_w = 0; in_w < input_width; ++in_w) { - value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)]; - } - } - output_data[Offset(output_dims, out_d, 0, 0, out_b)] = - value / (input_width * input_height); - } - } -} - template void GenericBroadcastSub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, const Dims<4>& input2_dims, @@ -6310,67 +5921,6 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, output_map.array() = input1_map.array().max(max_value); } -template -void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, - T2* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("ArgMax"); - - // The current ArgMax implemention can only determine the index of the maximum - // value in the last dimension. So the axis argument is ignored. - - // For ArgMax, the number of output dimensions = (number of input dimensions - - // 1). For the sake of simplicity, the output dimensions are equal to the - // input dimensions here. We enforce the constraint that the last dimension - // must always be 1. - TFLITE_DCHECK_EQ(ArraySize(output_dims, 0), 1); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = ArraySize(input_dims, 0); - for (int i = 0; i < outer_size; ++i) { - auto max_value = *input_data; - ++input_data; - int max_index = 0; - for (int d = 1; d < depth; ++d) { - const auto& curr_value = *input_data; - if (curr_value > max_value) { - max_value = curr_value; - max_index = d; - } - ++input_data; - } - *output_data = max_index; - ++output_data; - } -} - -template -void Transpose(const T* input, const Dims<4>& input_dims, T* output, - const Dims<4>& output_dims, const int* permuted_axes) { - int out_sizes[4]; - // Compute the inverse permutation array so we can do an output centered - // transpose. Also, check to make sure output_dims is matching input_dims. - for (int k = 0; k < 4; k++) { - out_sizes[k] = - MatchingArraySize(input_dims, permuted_axes[k], output_dims, k); - } - - // Naive transpose loop (iterate on output index and compute input index). - int o[4]; // loop index (on output). - int i[4]; - for (o[3] = 0; o[3] < out_sizes[3]; o[3]++) { - i[permuted_axes[3]] = o[3]; - for (o[2] = 0; o[2] < out_sizes[2]; o[2]++) { - i[permuted_axes[2]] = o[2]; - for (o[1] = 0; o[1] < out_sizes[1]; o[1]++) { - i[permuted_axes[1]] = o[1]; - for (o[0] = 0; o[0] < out_sizes[0]; o[0]++) { - i[permuted_axes[0]] = o[0]; - output[Offset(output_dims, o)] = input[Offset(input_dims, i)]; - } - } - } - } -} - template void TransposeIm2col(const T* input_data, const Dims<4>& input_dims, const Dims<4>& filter_dims, int stride_width, diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc index 57ee859115cddbcbccae24ff639e848340d8e2ee..e224980493aa11f642da103ee7d7377b6c4b1da0 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + #include #include #include @@ -126,4 +127,16 @@ void NudgeQuantizationRange(const float min, const float max, *nudged_max = (quant_max_float - nudged_zero_point) * (*scale); } +bool CheckedLog2(const float x, int* log2_result) { + // Using TfLiteRound instead of std::round and std::log instead of + // std::log2 to work around these fuctions being missing in a toolchain + // used in some TensorFlow tests as of May 2018. + const float x_log2 = std::log(x) * (1.0f / std::log(2.0f)); + const float x_log2_rounded = TfLiteRound(x_log2); + const float x_log2_fracpart = x_log2 - x_log2_rounded; + + *log2_result = static_cast(x_log2_rounded); + return std::abs(x_log2_fracpart) < 1e-3; +} + } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h index 182ee782c76fcccedc99327d47805b49bfb8580d..525857a2e6f73276d0a6e64770947169033c7667 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h @@ -218,6 +218,11 @@ void NudgeQuantizationRange(const float min, const float max, const int quant_min, const int quant_max, float* nudged_min, float* nudged_max, float* scale); +// If x is approximately a power of two (with any positive or negative +// exponent), stores that exponent (i.e. log2(x)) in *log2_result, otherwise +// returns false. +bool CheckedLog2(const float x, int* log2_result); + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h index 6f5f6a3e6fa905f594c0361b163b5b817306dafc..878b2441b4f2828a014673f5bd80fb8aa29514db 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h @@ -34,15 +34,297 @@ inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) { template void L2Normalization(const float* input_data, const Dims<4>& input_dims, float* output_data, const Dims<4>& output_dims) { - return L2Normalization(input_data, DimsToShape(input_dims), output_data, - DimsToShape(output_dims)); + L2Normalization(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); } inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, int32 input_zero_point, uint8* output_data, const Dims<4>& output_dims) { - return L2Normalization(input_data, DimsToShape(input_dims), input_zero_point, - output_data, DimsToShape(output_dims)); + L2Normalization(input_data, DimsToShape(input_dims), input_zero_point, + output_data, DimsToShape(output_dims)); +} + +inline void Relu(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Relu(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Relu1(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Relu1(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Relu6(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Relu6(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, float* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, filter_width, filter_height, + output_activation_min, output_activation_max, output_data, + DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, filter_width, filter_height, + output_activation_min, output_activation_max, output_data, + DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + L2Pool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, filter_width, filter_height, + output_activation_min, output_activation_max, output_data, + DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void Softmax(const float* input_data, const Dims<4>& input_dims, + float beta, float* output_data, + const Dims<4>& output_dims) { + Softmax(input_data, DimsToShape(input_dims), beta, output_data, + DimsToShape(output_dims)); +} + +inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_beta_multiplier, int32 input_beta_left_shift, + int diff_min, uint8* output_data, + const Dims<4>& output_dims) { + Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier, + input_beta_left_shift, diff_min, output_data, + DimsToShape(output_dims)); +} + +inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + LogSoftmax(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_multiplier, int32 input_left_shift, + int32 reverse_scaling_divisor, + int32 reverse_scaling_right_shift, int diff_min, + uint8* output_data, const Dims<4>& output_dims) { + LogSoftmax(input_data, DimsToShape(input_dims), input_multiplier, + input_left_shift, reverse_scaling_divisor, + reverse_scaling_right_shift, diff_min, output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), input_zero_point, + input_range_radius, input_multiplier, input_left_shift, output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const int16* input_data, const Dims<4>& input_dims, + int16* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), input_zero_point, + input_range_radius, input_multiplier, input_left_shift, output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const int16* input_data, const Dims<4>& input_dims, + int input_left_shift, int16* output_data, + const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data, + DimsToShape(output_dims)); } } // namespace reference_ops diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 6cef94a606aa6fbc39f3105b9b7aca1af4092970..9357e7407eb83fe8ea3486dfdde8742fc6323ee9 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -697,7 +697,7 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, } } -inline void ExperimentalShuffledFullyConnected( +inline void ShuffledFullyConnected( const uint8* input_data, const Dims<4>& input_dims, const uint8* shuffled_weights_data, const Dims<4>& weights_dims, const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier, @@ -914,9 +914,9 @@ void GlobalBatchNormalization(const float* input_data, } } -inline void Relu(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(input_dims, output_dims); +inline void Relu(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { const float val = input_data[i]; const float lower = 0; @@ -925,9 +925,10 @@ inline void Relu(const float* input_data, const Dims<4>& input_dims, } } -inline void Relu1(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(input_dims, output_dims); +inline void Relu1(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)"); + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { const float val = input_data[i]; const float upper = 1; @@ -937,9 +938,10 @@ inline void Relu1(const float* input_data, const Dims<4>& input_dims, } } -inline void Relu6(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(input_dims, output_dims); +inline void Relu6(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)"); + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { const float val = input_data[i]; const float upper = 6; @@ -949,6 +951,19 @@ inline void Relu6(const float* input_data, const Dims<4>& input_dims, } } +inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data, + const RuntimeShape& input_shape, uint8* output_data, + const RuntimeShape& output_shape) { + gemmlowp::ScopedProfilingLabel label("Quantized ReluX (not fused)"); + const int flat_size = MatchingFlatSize(input_shape, output_shape); + for (int i = 0; i < flat_size; ++i) { + const uint8 val = input_data[i]; + const uint8 clamped = + val > max_value ? max_value : val < min_value ? min_value : val; + output_data[i] = clamped; + } +} + template void L2Normalization(const float* input_data, const RuntimeShape& input_shape, float* output_data, const RuntimeShape& output_shape) { @@ -1049,10 +1064,11 @@ inline void L2Normalization(const uint8* input_data, } } -inline void Add(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { +template +inline void Add(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T output_activation_min, T output_activation_max, + T* output_data, const Dims<4>& output_dims) { const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( @@ -1134,22 +1150,12 @@ inline void Add(int left_shift, const uint8* input1_data, } } -template inline void Add(const int16* input1_data, const Dims<4>& input1_dims, int input1_shift, const int16* input2_data, const Dims<4>& input2_dims, int input2_shift, int16 output_activation_min, int16 output_activation_max, int16* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, -32768); - TFLITE_DCHECK_EQ(output_activation_max, 32767); - } const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); @@ -1175,6 +1181,28 @@ inline void Add(const int16* input1_data, const Dims<4>& input1_dims, } } +template +inline void Add(const int16* input1_data, const Dims<4>& input1_dims, + int input1_shift, const int16* input2_data, + const Dims<4>& input2_dims, int input2_shift, + int16 output_activation_min, int16 output_activation_max, + int16* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, -32768); + TFLITE_DCHECK_EQ(output_activation_max, 32767); + } + + Add(input1_data, input1_dims, input1_shift, input2_data, input2_dims, + input2_shift, output_activation_min, output_activation_max, output_data, + output_dims); +} + // TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then @@ -1755,7 +1783,6 @@ template void Concatenation(int concat_dim, const Scalar* const* input_data, const Dims<4>* const* input_dims, int inputs_count, Scalar* output_data, const Dims<4>& output_dims) { - TFLITE_DCHECK_GT(inputs_count, 1); int concat_size = 0; for (int i = 0; i < inputs_count; i++) { for (int j = 0; j < 4; j++) { @@ -1766,7 +1793,9 @@ void Concatenation(int concat_dim, const Scalar* const* input_data, concat_size += ArraySize(*input_dims[i], concat_dim); } TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim)); - TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + // For now we don't have a model with a Concatenation with fused activation. + TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone); int outer_size = 1; for (int i = concat_dim + 1; i < 4; i++) { outer_size *= output_dims.sizes[i]; @@ -2244,18 +2273,21 @@ inline int NodeOffset(int b, int h, int w, int height, int width) { return (b * height + h) * width + w; } -inline void AveragePool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, +inline void AveragePool(const float* input_data, + const RuntimeShape& input_shape, int stride_width, + int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float output_activation_min, float output_activation_max, float* output_data, - const Dims<4>& output_dims) { - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + const RuntimeShape& output_shape) { + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { @@ -2279,12 +2311,12 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims, const int in_x = in_x_origin + filter_x; const int in_y = in_y_origin + filter_y; total += - input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + input_data[Offset(input_shape, batch, in_y, in_x, channel)]; filter_count++; } } const float average = total / filter_count; - output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = ActivationFunctionWithMinMax(average, output_activation_min, output_activation_max); } @@ -2293,42 +2325,22 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void AveragePool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - float* output_data, const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, - int filter_height, float* output_data, - const Dims<4>& output_dims) { - AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - -inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, +inline void AveragePool(const uint8* input_data, + const RuntimeShape& input_shape, int stride_width, + int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, int32 output_activation_min, int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { @@ -2351,14 +2363,15 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, ++filter_x) { const int in_x = in_x_origin + filter_x; const int in_y = in_y_origin + filter_y; - acc += input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + acc += + input_data[Offset(input_shape, batch, in_y, in_x, channel)]; filter_count++; } } acc = (acc + filter_count / 2) / filter_count; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); - output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = static_cast(acc); } } @@ -2366,50 +2379,19 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void AveragePool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, - int filter_height, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -inline void L2Pool(const float* input_data, const Dims<4>& input_dims, +inline void L2Pool(const float* input_data, const RuntimeShape& input_shape, int stride_width, int stride_height, int pad_width, int pad_height, int filter_width, int filter_height, float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + float* output_data, const RuntimeShape& output_shape) { + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { @@ -2433,13 +2415,13 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims, const int in_x = in_x_origin + filter_x; const int in_y = in_y_origin + filter_y; const float val = - input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + input_data[Offset(input_shape, batch, in_y, in_x, channel)]; sum_squares += val * val; filter_count++; } } const float l2pool_result = std::sqrt(sum_squares / filter_count); - output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = ActivationFunctionWithMinMax(l2pool_result, output_activation_min, output_activation_max); } @@ -2448,40 +2430,19 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void L2Pool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, float* output_data, - const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - - L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - float* output_data, const Dims<4>& output_dims) { - L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - -inline void MaxPool(const float* input_data, const Dims<4>& input_dims, +inline void MaxPool(const float* input_data, const RuntimeShape& input_shape, int stride_width, int stride_height, int pad_width, int pad_height, int filter_width, int filter_height, float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + float* output_data, const RuntimeShape& output_shape) { + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { @@ -2505,10 +2466,10 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims, const int in_y = in_y_origin + filter_y; max = std::max( max, - input_data[Offset(input_dims, channel, in_x, in_y, batch)]); + input_data[Offset(input_shape, batch, in_y, in_x, channel)]); } } - output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = ActivationFunctionWithMinMax(max, output_activation_min, output_activation_max); } @@ -2517,42 +2478,22 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void MaxPool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, float* output_data, - const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - float* output_data, const Dims<4>& output_dims) { - MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - -inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, +inline void MaxPool(const uint8* input_data, const RuntimeShape& input_shape, int stride_width, int stride_height, int pad_width, int pad_height, int filter_width, int filter_height, int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { TFLITE_DCHECK_LE(output_activation_min, output_activation_max); TFLITE_DCHECK_GE(output_activation_min, 0); TFLITE_DCHECK_LE(output_activation_max, 255); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { @@ -2576,12 +2517,12 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, const int in_y = in_y_origin + filter_y; max = std::max( max, - input_data[Offset(input_dims, channel, in_x, in_y, batch)]); + input_data[Offset(input_shape, batch, in_y, in_x, channel)]); } } max = std::max(max, output_activation_min); max = std::min(max, output_activation_max); - output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = static_cast(max); } } @@ -2589,38 +2530,6 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void MaxPool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - inline void LocalResponseNormalization(const float* input_data, const Dims<4>& input_dims, int range, float bias, float alpha, float beta, @@ -2644,11 +2553,14 @@ inline void LocalResponseNormalization(const float* input_data, } } -inline void Softmax(const float* input_data, const Dims<4>& input_dims, +inline void Softmax(const float* input_data, const RuntimeShape& input_shape, float beta, float* output_data, - const Dims<4>& output_dims) { - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const RuntimeShape& output_shape) { + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { // Find max element value which we'll use to ensure numerical stability @@ -2673,10 +2585,10 @@ inline void Softmax(const float* input_data, const Dims<4>& input_dims, } } -inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, +inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape, int32 input_beta_multiplier, int32 input_beta_left_shift, int diff_min, uint8* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { // The representation chosen for the input to the exp() function is Q5.26. // We need to leave extra space since values that we skip might be as large as // -32 before multiplying by input_beta_multiplier, and therefore as large as @@ -2689,8 +2601,11 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, using FixedPointAccum = gemmlowp::FixedPoint; using FixedPoint0 = gemmlowp::FixedPoint; - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { uint8 max_in_row = 0; @@ -2751,10 +2666,13 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, } } -inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); +inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { // Find max element value which we'll use to ensure numerical stability @@ -2894,11 +2812,11 @@ log_x_for_x_greater_than_or_equal_to_1( input_val); } -inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, +inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape, int32 input_multiplier, int32 input_left_shift, int32 reverse_scaling_divisor, int32 reverse_scaling_right_shift, int diff_min, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { // The representation chosen for the input to the exp() function is Q5.26. // We need to leave extra space since values that we skip might be as large as // -32 before multiplying by input_beta_multiplier, and therefore as large as @@ -2912,8 +2830,11 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, using FixedPointAccum = gemmlowp::FixedPoint; using FixedPoint0 = gemmlowp::FixedPoint; - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { uint8 max_in_row = 0; @@ -2977,9 +2898,9 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Logistic(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input_dims); +inline void Logistic(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { float val = input_data[i]; @@ -2988,11 +2909,11 @@ inline void Logistic(const float* input_data, const Dims<4>& input_dims, } } -inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, +inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape, int32 input_zero_point, int32 input_range_radius, int32 input_multiplier, int input_left_shift, - uint8* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input_dims); + uint8* output_data, const RuntimeShape& output_shape) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { const uint8 input_val_u8 = input_data[i]; @@ -3026,9 +2947,9 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Logistic(const int16* input_data, const Dims<4>& input_dims, - int16* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input_dims); +inline void Logistic(const int16* input_data, const RuntimeShape& input_shape, + int16* output_data, const RuntimeShape& output_shape) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { // F0 uses 0 integer bits, range [-1, 1]. @@ -3044,9 +2965,9 @@ inline void Logistic(const int16* input_data, const Dims<4>& input_dims, } } -inline void Tanh(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input_dims); +inline void Tanh(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { float val = input_data[i]; @@ -3055,12 +2976,12 @@ inline void Tanh(const float* input_data, const Dims<4>& input_dims, } } -inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, +inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape, int32 input_zero_point, int32 input_range_radius, int32 input_multiplier, int input_left_shift, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { const int32 output_zero_point = 128; - const int flat_size = MatchingFlatSize(output_dims, input_dims); + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { const uint8 input_val_u8 = input_data[i]; @@ -3095,15 +3016,15 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Tanh(const int16* input_data, const Dims<4>& input_dims, +inline void Tanh(const int16* input_data, const RuntimeShape& input_shape, int input_left_shift, int16* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { // Support for shifts is limited until we have a parameterized version of // SaturatingRoundingMultiplyByPOT(). TFLITE_DCHECK_GE(input_left_shift, 0); TFLITE_DCHECK_LE(input_left_shift, 1); - const int flat_size = MatchingFlatSize(output_dims, input_dims); + const int flat_size = MatchingFlatSize(input_shape, output_shape); // F0 uses 0 integer bits, range [-1, 1]. // This is the return type of math functions such as tanh, logistic, @@ -3435,7 +3356,7 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims, template inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, - int begin_mask, int end_mask, + int begin_mask, int end_mask, int shrink_axis_mask, const std::vector& start_indices, const std::vector& stop_indices, const std::vector& strides, T* output_data, @@ -3447,20 +3368,24 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, TFLITE_DCHECK_EQ(strides.size(), 4); const int start_b = strided_slice::StartForAxis(begin_mask, start_indices, strides, input_dims.sizes, 3); - const int stop_b = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 3); + const int stop_b = + strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, + strides, input_dims.sizes, 3, start_b); const int start_h = strided_slice::StartForAxis(begin_mask, start_indices, strides, input_dims.sizes, 2); - const int stop_h = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 2); + const int stop_h = + strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, + strides, input_dims.sizes, 2, start_h); const int start_w = strided_slice::StartForAxis(begin_mask, start_indices, strides, input_dims.sizes, 1); - const int stop_w = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 1); + const int stop_w = + strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, + strides, input_dims.sizes, 1, start_w); const int start_d = strided_slice::StartForAxis(begin_mask, start_indices, strides, input_dims.sizes, 0); - const int stop_d = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 0); + const int stop_d = + strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, + strides, input_dims.sizes, 0, start_d); T* out_ptr = output_data; for (int in_b = start_b; @@ -3523,8 +3448,6 @@ inline void Exp(const T* input_data, const size_t num_elements, } // A generic reduce method that can be used for reduce_sum, reduce_mean, etc. -// It takes a reducer function as input and returns false when numeric overflow -// is detected. // This method iterates through input data and reduce elements along the // dimensions given in axis. template @@ -3532,8 +3455,7 @@ inline bool Reduce(const In* input_data, const int* input_dims, const int* output_dims, const int input_num_dims, const int output_num_dims, const int* axis, const int num_axis, int* input_iter, - Out reducer(Out current, const In in, bool* overflow), - Out* output_data) { + Out reducer(Out current, const In in), Out* output_data) { // Reset input iterator. TFLITE_DCHECK(input_num_dims > 0); for (int idx = 0; idx < input_num_dims; ++idx) { @@ -3545,10 +3467,8 @@ inline bool Reduce(const In* input_data, const int* input_dims, ReducedOutputOffset(input_num_dims, input_dims, input_iter, 0, nullptr); size_t output_offset = ReducedOutputOffset(input_num_dims, input_dims, input_iter, num_axis, axis); - bool overflow = false; - output_data[output_offset] = reducer(output_data[output_offset], - input_data[input_offset], &overflow); - if (overflow) return false; + output_data[output_offset] = + reducer(output_data[output_offset], input_data[input_offset]); } while (NextIndex(input_num_dims, input_dims, input_iter)); return true; } @@ -3583,7 +3503,7 @@ inline bool ReduceSumImpl(const In* input_data, const int* input_dims, const int output_num_dims, const int* axis, const int num_axis, int* input_iter, Out* output_data) { - auto reducer = [](Out current, const In in, bool* overflow) -> Out { + auto reducer = [](Out current, const In in) -> Out { const Out actual_in = static_cast(in); return current + actual_in; }; @@ -3592,6 +3512,39 @@ inline bool ReduceSumImpl(const In* input_data, const int* input_dims, output_data); } +// Computes the sum of elements across dimensions given in axis. +template +inline bool Sum(const T* input_data, const int* input_dims, + const int input_num_dims, T* output_data, + const int* output_dims, const int output_num_dims, + const int* axis, const int num_axis_dimensions, bool keep_dims, + int* temp_index, int* resolved_axis) { + // Reset output data. + size_t num_outputs = 1; + for (int idx = 0; idx < output_num_dims; ++idx) { + size_t current = static_cast(output_dims[idx]); + // Overflow prevention. + if (num_outputs > std::numeric_limits::max() / current) { + return false; + } + num_outputs *= current; + } + for (size_t idx = 0; idx < num_outputs; ++idx) { + output_data[idx] = T(); + } + + // Resolve axis. + int num_resolved_axis = 0; + if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis, + &num_resolved_axis)) { + return false; + } + + return ReduceSumImpl(input_data, input_dims, output_dims, + input_num_dims, output_num_dims, resolved_axis, + num_resolved_axis, temp_index, output_data); +} + // Computes the mean of elements across dimensions given in axis. // It does so in two stages, first calculates the sum of elements along the axis // then divides it by the number of element in axis. @@ -3794,7 +3747,7 @@ void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, template void Transpose(const T* input, const Dims<4>& input_dims, T* output, - const Dims<4>& output_dims, int* permuted_axes) { + const Dims<4>& output_dims, const int* permuted_axes) { int out_sizes[4]; // Compute the inverse permutation array so we can do an output centered // transpose. Also, check to make sure output_dims is matching input_dims. @@ -3844,7 +3797,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, // computing their influence on the output, rather than looping through the // output elements in the typical "gather" access pattern of a conv. We // therefore must initialize the output array to zero. - for (int i = 0; i < FlatSize(output_dims); i++) { + const int num_elements = FlatSize(output_dims); + for (int i = 0; i < num_elements; i++) { output_data[i] = 0.0f; } @@ -4133,6 +4087,36 @@ inline void SparseToDense(const std::vector>& indices, } } +template +inline void Pow(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = std::pow(input1_data[i], input2_data[i]); + } +} + +template +inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + std::pow(input1_data[SubscriptToIndex(desc1, c, x, y, b)], + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + } // namespace reference_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc index d781a7b642036f3c5ddaa366f257fe26511c83c3..a7dad3c14e60fac9da9c0bcfd5d1d4c8f10b71c7 100644 --- a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc @@ -32,19 +32,21 @@ namespace tflite { namespace { void RunSoftmaxFloatReference(const uint8* input_data, - const Dims<4>& dims_common, int32 input_offset, - const double input_scale, int stride, float beta, + const RuntimeShape& shape_common, + int32 input_offset, const double input_scale, + int stride, float beta, uint8* reference_output_data) { - const int ref_buffer_size = RequiredBufferSizeForDims(dims_common); + const int ref_buffer_size = shape_common.FlatSize(); std::vector reference_dequant_data(ref_buffer_size); std::vector reference_output_float_data(ref_buffer_size); // Reference data generated via Dequant of input into float, and then applying // float Softmax. - reference_ops::Dequantize(input_data, dims_common, input_offset, input_scale, - reference_dequant_data.data(), dims_common); - optimized_ops::Softmax(reference_dequant_data.data(), dims_common, beta, - reference_output_float_data.data(), dims_common); + reference_ops::Dequantize( + input_data, ToRuntimeDims(shape_common), input_offset, input_scale, + reference_dequant_data.data(), ToRuntimeDims(shape_common)); + optimized_ops::Softmax(reference_dequant_data.data(), shape_common, beta, + reference_output_float_data.data(), shape_common); // Work with quantized scaling for Softmax, under which 256 represents 1, but // we limit this to 255. for (int i = 0; i < ref_buffer_size; i++) { @@ -55,9 +57,9 @@ void RunSoftmaxFloatReference(const uint8* input_data, } void CheckOutputData(const uint8* test_output, const uint8* reference_output, - const Dims<4>& dims_common, const string& check_label, - bool be_exacting) { - const int buffer_size = RequiredBufferSizeForDims(dims_common); + const RuntimeShape& shape_common, + const string& check_label, bool be_exacting) { + const int buffer_size = shape_common.FlatSize(); // While calculating some metrics in floating point, we work with quantized // scaling. std::vector diff(buffer_size); @@ -91,15 +93,15 @@ void CheckOutputData(const uint8* test_output, const uint8* reference_output, // Runs the Softmax and compares against the float reference implementation and // the quantized reference implementation. -void RunOneSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common, - int32 input_offset, const double input_scale, int stride, - float beta) { - const int buffer_size = RequiredBufferSizeForDims(dims_common); +void RunOneSoftmaxTest(const uint8* input_data, + const RuntimeShape& shape_common, int32 input_offset, + const double input_scale, int stride, float beta) { + const int buffer_size = shape_common.FlatSize(); std::vector optimized_softmax_output(buffer_size); std::vector reference_float_softmax_output(buffer_size); std::vector reference_quant_softmax_output(buffer_size); - RunSoftmaxFloatReference(input_data, dims_common, input_offset, input_scale, + RunSoftmaxFloatReference(input_data, shape_common, input_offset, input_scale, stride, beta, reference_float_softmax_output.data()); int32 input_beta_multiplier; @@ -113,21 +115,21 @@ void RunOneSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common, const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits, input_beta_left_shift); - optimized_ops::Softmax(input_data, dims_common, input_beta_multiplier, + optimized_ops::Softmax(input_data, shape_common, input_beta_multiplier, input_beta_left_shift, diff_min, - optimized_softmax_output.data(), dims_common); - reference_ops::Softmax(input_data, dims_common, input_beta_multiplier, + optimized_softmax_output.data(), shape_common); + reference_ops::Softmax(input_data, shape_common, input_beta_multiplier, input_beta_left_shift, diff_min, - reference_quant_softmax_output.data(), dims_common); + reference_quant_softmax_output.data(), shape_common); CheckOutputData(optimized_softmax_output.data(), - reference_float_softmax_output.data(), dims_common, + reference_float_softmax_output.data(), shape_common, "Optimized vs float reference", false); CheckOutputData(optimized_softmax_output.data(), - reference_quant_softmax_output.data(), dims_common, + reference_quant_softmax_output.data(), shape_common, "Optimized vs quant reference", true); CheckOutputData(reference_quant_softmax_output.data(), - reference_float_softmax_output.data(), dims_common, + reference_float_softmax_output.data(), shape_common, "Quant reference vs float reference", false); } @@ -150,13 +152,13 @@ bool TryOneUniformSoftmax() { const int32 input_offset = UniformRandomInt(-256, 0); const float beta = 1.0f + ExponentialRandomPositiveFloat(0.9f, 2, 10); - Dims<4> dims_common = - MakeDimsForInference(input_depth, input_width, input_height, batch); - const int buffer_size = RequiredBufferSizeForDims(dims_common); + auto shape_common = + RuntimeShape({batch, input_height, input_width, input_depth}); + const int buffer_size = shape_common.FlatSize(); std::vector input_data(buffer_size); FillRandom(&input_data); - RunOneSoftmaxTest(input_data.data(), dims_common, input_offset, input_scale, + RunOneSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale, stride, beta); return true; } @@ -188,14 +190,14 @@ bool TryOneSkyscraperSoftmax(bool small_depth) { const int middle_min = UniformRandomInt(0, 255); const int sides_max = UniformRandomInt(0, middle_min); - Dims<4> dims_common = - MakeDimsForInference(input_depth, input_width, input_height, batch); - const int buffer_size = RequiredBufferSizeForDims(dims_common); + auto shape_common = + RuntimeShape({batch, input_height, input_width, input_depth}); + const int buffer_size = shape_common.FlatSize(); std::vector input_data(buffer_size); FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min, sides_max); - RunOneSoftmaxTest(input_data.data(), dims_common, input_offset, input_scale, + RunOneSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale, stride, beta); return true; } diff --git a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h index ef77371bf65cc975dfa35275c8daa32de112a249..5994fad5c73df1dde6e33ba46dbd6e0802ea61be 100644 --- a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h +++ b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h @@ -74,12 +74,22 @@ inline int StartForAxis(int begin_mask, // size 4, this function would return 4 as the stop, because it is one past the // "real" indices of 0, 1, 2 & 3. template -inline int StopForAxis(int end_mask, std::vector const& stop_indices, +inline int StopForAxis(int end_mask, int shrink_axis_mask, + std::vector const& stop_indices, std::vector const& strides, - int const* input_shape, int axis) { + int const* input_shape, int axis, int start_for_axis) { // Begin with the specified index + const bool shrink_axis = shrink_axis_mask & (1 << axis); int stop = stop_indices[axis]; + // When shrinking an axis, the end position does not matter (and can be + // incorrect when negative indexing is used, see Issue #19260). Always use + // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has + // already been adjusted for negative indices. + if (shrink_axis) { + stop = start_for_axis + 1; + } + // end_mask override if (end_mask & (1 << axis)) { if (strides[axis] > 0) { @@ -93,7 +103,7 @@ inline int StopForAxis(int end_mask, std::vector const& stop_indices, } // Handle negative indices - int axis_size = input_shape[axis]; + const int axis_size = input_shape[axis]; if (stop < 0) { stop += axis_size; } diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h index 518bee1c6369d3ce93d1b98e19dba7615b5844dc..ee2af5b46046c9e8bdc5816d5b6e9e9100cdc240 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ +#include #include #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/internal/types.h" @@ -54,6 +55,13 @@ inline bool* GetTensorData(TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.b : nullptr; } +template <> +inline std::complex* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr + ? reinterpret_cast*>(tensor->data.c64) + : nullptr; +} + template inline const T* GetTensorData(const TfLiteTensor* tensor); @@ -87,6 +95,13 @@ inline const bool* GetTensorData(const TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.b : nullptr; } +template <> +inline const std::complex* GetTensorData(const TfLiteTensor* tensor) { + return tensor != nullptr + ? reinterpret_cast*>(tensor->data.c64) + : nullptr; +} + inline int RemapDim(int max_dimensions, int d) { return max_dimensions - d - 1; } diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index 64f4881a4686525fa6b56c30c1411fe5c91334b2..fa2420713fea4faa3596251a95c2ed9606878b98 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -25,6 +25,67 @@ namespace tflite { enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu }; enum class PaddingType { kNone, kSame, kValid }; +// This enumeration allows for non-default formats for the weights array +// of a fully-connected operator, allowing the use of special optimized +// runtime paths. +enum class FullyConnectedWeightsFormat : uint8 { + // Default format (flat 2D layout, the inner contiguous dimension + // is input_depth, the outer non-contiguous dimension is output_depth) + kDefault, + // Summary: optimized layout for fast CPU runtime implementation, + // aimed specifically at ARM CPUs at the moment, and specialized for + // 8-bit quantized layers. + // + // The use case we're concerned with here is: 8-bit quantization, + // large weights matrix that doesn't fit in cache (e.g. 4096x2048 in + // a key application that drove this), very small batch size (e.g. 1 -- 4). + // + // Even with 8-bit quantization of weights, the performance of memory + // accesses to the weights can become the dominant issue when + // the batch size is small, so each weight value is used in only a few + // arithmetic ops, i.e. the fully-connected node has a low arithmetic + // intensity. The specific issues that arise are of three kinds: + // (1) One may, ideally, max out DRAM bandwidth, i.e. be truly memory + // bound. That's the "good" issue to run into. + // (2) One may run into sub-optimal pre-fetching: the data hasn't been + // prefetched into the cache by the time we need it. + // (3) One may run into cache aliasing: multiple values that are + // pre-fetched, alias each other in the L1 cache (which typically + // has only 4-way set associativity in ARM CPUs) and thus evict + // each other before we get to using them. + // + // The point of this shuffling is to avoid issues (2) and (3) so that + // we get as fast as possible given only the hard constraint (1). + // This is achieved by turning the difficulty into a solution: the + // difficulty, that each value loaded from memory is used only in + // one kernel iteration, making this operation memory-intensive, hints at + // the solution, of shuffling the weights so that they are stored in the + // exact order as the kernel needs to load them, so that the memory + // accesses made by the kernel are trivial. This solves (2) because the + // trivial memory access pattern allows the CPU's automatic prefetching + // to perform very well (no need even for preload instructions), and this + // solves (3) because the values being loaded concurrently are now + // contiguous in the address space, thus don't alias each other in the cache. + // + // On ARM, we typically want our kernel to process a 4x16 block of weights + // at a time, because: + // - 16 is the number of bytes in a NEON register. + // - 4 is how many rows we need to handle concurrently in the kernel in + // order to have sufficient mutual independence of instructions to + // maximize arithmetic throughput. + // + // Finally, the 'Int8' part in the name refers to the fact that this + // weights format has each weights value encoded as a signed int8 value, + // even if the data type of the weights buffer is uint8. This is intended + // to save runtime kernels the effort to have to XOR the top bit of these + // bytes before using them in signed arithmetic, see this file for more + // explanations on the 'signed int8 trick' in matrix multiplication kernels: + // + // tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc + // + kShuffled4x16Int8, +}; + // Quantization parameters, determining the mapping of quantized values // to real values (i.e. determining how quantized values are mathematically // interpreted). @@ -294,6 +355,50 @@ inline int RequiredBufferSizeForDims(const Dims<4>& dims) { return FlatSize(dims); } +// Flat size calculation, checking that dimensions match with one or more other +// arrays. +inline int MatchingFlatSize(const RuntimeShape& shape, + const RuntimeShape& check_shape_0) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + return shape.FlatSize(); +} + +inline int MatchingFlatSize(const RuntimeShape& shape, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + return MatchingFlatSize(shape, check_shape_1); +} + +inline int MatchingFlatSize(const RuntimeShape& shape, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1, + const RuntimeShape& check_shape_2) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + return MatchingFlatSize(shape, check_shape_1, check_shape_2); +} + +inline int MatchingFlatSize(const RuntimeShape& shape, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1, + const RuntimeShape& check_shape_2, + const RuntimeShape& check_shape_3) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + return MatchingFlatSize(shape, check_shape_1, check_shape_2, check_shape_3); +} + // Flat size calculation, checking that dimensions match with one or more other // arrays. template @@ -320,7 +425,7 @@ inline int MatchingFlatSize(const Dims& dims, const Dims& check_dims_0, for (int i = 0; i < N; ++i) { TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i)); } - return FlatSize(dims, check_dims_1, check_dims_2); + return MatchingFlatSize(dims, check_dims_1, check_dims_2); } template @@ -331,7 +436,7 @@ inline int MatchingFlatSize(const Dims& dims, const Dims& check_dims_0, for (int i = 0; i < N; ++i) { TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i)); } - return FlatSize(dims, check_dims_1, check_dims_2, check_dims_3); + return MatchingFlatSize(dims, check_dims_1, check_dims_2, check_dims_3); } // Data is required to be contiguous, and so many operators can use either the diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/contrib/lite/kernels/kernel_util.cc index 184028427fb193aa99cf155961c16eda1298e326..08f942c933552aa6ca7369550c928efba9e2e93e 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util.cc +++ b/tensorflow/contrib/lite/kernels/kernel_util.cc @@ -43,12 +43,11 @@ TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context, return kTfLiteOk; } -void CalculateActivationRangeUint8(TfLiteFusedActivation activation, - TfLiteTensor* output, int32_t* act_min, - int32_t* act_max) { - const int32_t qmin = std::numeric_limits::min(); - const int32_t qmax = std::numeric_limits::max(); - +namespace { +void CalculateActivationRangeQuantizedImpl(TfLiteFusedActivation activation, + int32_t qmin, int32_t qmax, + TfLiteTensor* output, + int32_t* act_min, int32_t* act_max) { const auto scale = output->params.scale; const auto zero_point = output->params.zero_point; @@ -70,23 +69,38 @@ void CalculateActivationRangeUint8(TfLiteFusedActivation activation, *act_max = qmax; } } - -void CalculateActivationRangeFloat(TfLiteFusedActivation activation, - float* activation_min, - float* activation_max) { - if (activation == kTfLiteActRelu) { - *activation_min = 0.f; - *activation_max = std::numeric_limits::max(); - } else if (activation == kTfLiteActRelu6) { - *activation_min = 0.f; - *activation_max = 6.f; - } else if (activation == kTfLiteActRelu1) { - *activation_min = -1.f; - *activation_max = 1.f; +} // namespace + +TfLiteStatus CalculateActivationRangeQuantized(TfLiteContext* context, + TfLiteFusedActivation activation, + TfLiteTensor* output, + int32_t* act_min, + int32_t* act_max) { + int32_t qmin = 0; + int32_t qmax = 0; + if (output->type == kTfLiteUInt8) { + qmin = std::numeric_limits::min(); + qmax = std::numeric_limits::max(); + } else if (output->type == kTfLiteInt16) { + qmin = std::numeric_limits::min(); + qmax = std::numeric_limits::max(); } else { - *activation_min = std::numeric_limits::lowest(); - *activation_max = std::numeric_limits::max(); + TF_LITE_ENSURE(context, false); } + + CalculateActivationRangeQuantizedImpl(activation, qmin, qmax, output, act_min, + act_max); + return kTfLiteOk; +} + +void CalculateActivationRangeUint8(TfLiteFusedActivation activation, + TfLiteTensor* output, int32_t* act_min, + int32_t* act_max) { + const int32_t qmin = std::numeric_limits::min(); + const int32_t qmax = std::numeric_limits::max(); + + CalculateActivationRangeQuantizedImpl(activation, qmin, qmax, output, act_min, + act_max); } bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2) { diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h index 82cded36f2ed2777daccafee5890f47c0d7254e8..c8ce3c917d5bf66e01fbae95c18dfe97b3c84bae 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util.h +++ b/tensorflow/contrib/lite/kernels/kernel_util.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ +#include + #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" @@ -86,14 +88,35 @@ TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context, TfLiteTensor* output, double* multiplier); -// Calculates the useful range of an activation layer given its activation -// tensor. +// Calculates the useful quantized range of an activation layer given its +// activation tensor. +TfLiteStatus CalculateActivationRangeQuantized(TfLiteContext* context, + TfLiteFusedActivation activation, + TfLiteTensor* output, + int32_t* act_min, + int32_t* act_max); void CalculateActivationRangeUint8(TfLiteFusedActivation activation, TfLiteTensor* output, int32_t* act_min, int32_t* act_max); -void CalculateActivationRangeFloat(TfLiteFusedActivation activation, - float* activation_min, - float* activation_max); +// Calculates the useful range of an activation layer given its activation +// tensor.a +template +void CalculateActivationRange(TfLiteFusedActivation activation, + T* activation_min, T* activation_max) { + if (activation == kTfLiteActRelu) { + *activation_min = 0; + *activation_max = std::numeric_limits::max(); + } else if (activation == kTfLiteActRelu6) { + *activation_min = 0; + *activation_max = 6; + } else if (activation == kTfLiteActRelu1) { + *activation_min = -1; + *activation_max = 1; + } else { + *activation_min = std::numeric_limits::lowest(); + *activation_max = std::numeric_limits::max(); + } +} // Return true if the given tensors have the same shape. bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2); diff --git a/tensorflow/contrib/lite/kernels/log_softmax_test.cc b/tensorflow/contrib/lite/kernels/log_softmax_test.cc index 62820a2f5113cb6ae252386aaf3842135383b79f..9a8d35e82cbc3a7e55246e6c06599b2838d1ee67 100644 --- a/tensorflow/contrib/lite/kernels/log_softmax_test.cc +++ b/tensorflow/contrib/lite/kernels/log_softmax_test.cc @@ -90,10 +90,9 @@ TEST(LogSoftmaxOpTest, CompareWithTFmini) { m.Invoke(); std::unique_ptr output_buffer(new float[input_size * batch_size]); - static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size}, - {1, 0, 0, input_size}}; - tflite::reference_ops::LogSoftmax(input_buffer, input_dims, - output_buffer.get(), input_dims); + auto input_shape = RuntimeShape({batch_size, 1, 1, input_size}); + tflite::reference_ops::LogSoftmax(input_buffer, input_shape, + output_buffer.get(), input_shape); std::vector expected; expected.insert(expected.end(), output_buffer.get(), diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc index eb26a02455ce2afccaa081a72d93a9ceeca746cc..3577ae6caa1e02ce2e5db2e8054ba9c2fccbe93e 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/gemm_support.h" #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" @@ -37,14 +38,17 @@ namespace builtin { namespace lstm { struct OpData { - // Which kernel type to use. Full kernel (18-inputs) or basic kernel - // (5-inputs). + // Which kernel type to use. Full kernel (18 or 20 inputs) or basic kernel + // (5 inputs). TfLiteLSTMKernelType kernel_type; - // Only used by full kernel. + + // These fields are only used by full kernel. + int activation_state_tensor_index; + int cell_state_tensor_index; int scratch_tensor_index; }; -// For full inputs kernel (18-inputs). +// For full inputs kernel (18 or 20 inputs). namespace full { // Input Tensors of size {n_batch, n_input} @@ -78,7 +82,16 @@ constexpr int kProjectionWeightsTensor = 16; // Optional // Projection bias tensor of size {n_output} constexpr int kProjectionBiasTensor = 17; // Optional +// If the node has 20 inputs, the following 2 tensors are used as state tensors. +// These are defined as variable tensors, and will be modified by this op. +constexpr int kInputActivationStateTensor = 18; +constexpr int kInputCellStateTensor = 19; + // Output tensors. +// * If the node has 18 inputs, these 2 tensors are used as state tensors. +// * If the node has 20 inputs, these 2 tensors are ignored. +// TODO(ycling): Make the 2 output state tensors optional, and propagate the +// state to output tensors when the 2 tensors present. constexpr int kOutputStateTensor = 0; constexpr int kCellStateTensor = 1; constexpr int kOutputTensor = 2; @@ -246,10 +259,31 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { OpData* op_data = reinterpret_cast(node->user_data); - // Check we have all the inputs and outputs we need. - TF_LITE_ENSURE_EQ(context, node->inputs->size, 18); TF_LITE_ENSURE_EQ(context, node->outputs->size, 3); + // True if the node is using input variable state tensors. It means: + // * The state tensors are defined as inputs. In this case it would be the + // 19th and 20th input tensors. + // * Otherwise, the output tensors are used to store states. + bool use_input_variable_states; + if (node->inputs->size == 20) { + use_input_variable_states = true; + op_data->activation_state_tensor_index = + node->inputs->data[kInputActivationStateTensor]; + op_data->cell_state_tensor_index = + node->inputs->data[kInputCellStateTensor]; + } else if (node->inputs->size == 18) { + use_input_variable_states = false; + op_data->activation_state_tensor_index = + node->outputs->data[kOutputStateTensor]; + op_data->cell_state_tensor_index = node->outputs->data[kCellStateTensor]; + } else { + context->ReportError( + context, "The LSTM Full kernel expects 18 or 20 inputs. Got %d inputs", + node->inputs->size); + return kTfLiteError; + } + // Inferring batch size, number of outputs and number of cells from the // input tensors. const TfLiteTensor* input = GetInput(context, node, kInputTensor); @@ -274,34 +308,47 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Check that input tensor dimensions matches with each other. CheckInputTensorDimensions(context, node, n_input, n_output, n_cell); - // Get the pointer to output, output_state and cell_state tensors. + // Get the pointer to output, activation_state and cell_state tensors. TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); - TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); - // Resize the output, output_state and cell_state tensors. + TfLiteTensor* activation_state = + &context->tensors[op_data->activation_state_tensor_index]; + TfLiteTensor* cell_state = + &context->tensors[op_data->cell_state_tensor_index]; + + if (use_input_variable_states) { + // Check the shape of input state tensors. + // These tensor may be 1D or 2D. It's fine as long as the total size is + // correct. + TF_LITE_ENSURE_EQ(context, NumElements(activation_state), + n_batch * n_output); + TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell); + } else { + // If the state tensors are outputs, this function takes the + // responsibility to resize the state tensors. + TfLiteIntArray* activation_state_size = TfLiteIntArrayCreate(2); + activation_state_size->data[0] = n_batch; + activation_state_size->data[1] = n_output; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_state, + activation_state_size)); + + TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2); + cell_size->data[0] = n_batch; + cell_size->data[1] = n_cell; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, cell_state, cell_size)); + // Mark state tensors as persistent tensors. + activation_state->allocation_type = kTfLiteArenaRwPersistent; + cell_state->allocation_type = kTfLiteArenaRwPersistent; + } + + // Resize the output tensors. TfLiteIntArray* output_size = TfLiteIntArrayCreate(2); output_size->data[0] = n_batch; output_size->data[1] = n_output; TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, output_size)); - TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2); - output_state_size->data[0] = n_batch; - output_state_size->data[1] = n_output; - TF_LITE_ENSURE_OK( - context, context->ResizeTensor(context, output_state, output_state_size)); - - TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2); - cell_size->data[0] = n_batch; - cell_size->data[1] = n_cell; - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, cell_state, cell_size)); - - // Mark state tensors as persistent tensors. - output_state->allocation_type = kTfLiteArenaRwPersistent; - cell_state->allocation_type = kTfLiteArenaRwPersistent; - // The weights are of consistent type, so it suffices to check one. // TODO(mirkov): create a utility/macro for this check, so all Ops can use it. const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 && @@ -337,7 +384,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { if (is_hybrid_op) { // Allocate temporary tensors to store quantized values of input, - // output_state and cell_state tensors. + // activation_state and cell_state tensors. node->temporaries->data[1] = op_data->scratch_tensor_index + 1; TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); input_quantized->type = kTfLiteUInt8; @@ -348,17 +395,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { input_quantized_size)); } node->temporaries->data[2] = op_data->scratch_tensor_index + 2; - TfLiteTensor* output_state_quantized = + TfLiteTensor* activation_state_quantized = GetTemporary(context, node, /*index=*/2); - output_state_quantized->type = kTfLiteUInt8; - output_state_quantized->allocation_type = kTfLiteArenaRw; - if (!TfLiteIntArrayEqual(output_state_quantized->dims, - output_state->dims)) { - TfLiteIntArray* output_state_quantized_size = - TfLiteIntArrayCopy(output_state->dims); - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, output_state_quantized, - output_state_quantized_size)); + activation_state_quantized->type = kTfLiteUInt8; + activation_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(activation_state_quantized->dims, + activation_state->dims)) { + TfLiteIntArray* activation_state_quantized_size = + TfLiteIntArrayCopy(activation_state->dims); + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, activation_state_quantized, + activation_state_quantized_size)); } node->temporaries->data[3] = op_data->scratch_tensor_index + 3; TfLiteTensor* cell_state_quantized = @@ -438,7 +485,7 @@ TfLiteStatus EvalFloat( const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, - TfLiteTensor* output_state, TfLiteTensor* cell_state, + TfLiteTensor* activation_state, TfLiteTensor* cell_state, TfLiteTensor* output) { const int n_batch = input->dims->data[0]; const int n_input = input->dims->data[1]; @@ -499,7 +546,7 @@ TfLiteStatus EvalFloat( const float* cell_bias_ptr = cell_bias->data.f; const float* output_gate_bias_ptr = output_gate_bias->data.f; - float* output_state_ptr = output_state->data.f; + float* activation_state_ptr = activation_state->data.f; float* cell_state_ptr = cell_state->data.f; float* output_ptr_batch = output->data.f; @@ -512,8 +559,8 @@ TfLiteStatus EvalFloat( cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, projection_bias_ptr, params, n_batch, n_cell, n_input, n_output, - output_state_ptr, cell_state_ptr, input_gate_scratch, forget_gate_scratch, - cell_scratch, output_gate_scratch, output_ptr_batch); + activation_state_ptr, cell_state_ptr, input_gate_scratch, + forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch); return kTfLiteOk; } @@ -536,9 +583,9 @@ TfLiteStatus EvalHybrid( const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized, - TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized, - TfLiteTensor* output_state, TfLiteTensor* cell_state, - TfLiteTensor* output) { + TfLiteTensor* activation_state_quantized, + TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state, + TfLiteTensor* cell_state, TfLiteTensor* output) { const int n_batch = input->dims->data[0]; const int n_input = input->dims->data[1]; // n_cell and n_output will be the same size when there is no projection. @@ -639,15 +686,15 @@ TfLiteStatus EvalHybrid( const float* cell_bias_ptr = cell_bias->data.f; const float* output_gate_bias_ptr = output_gate_bias->data.f; - float* output_state_ptr = output_state->data.f; + float* activation_state_ptr = activation_state->data.f; float* cell_state_ptr = cell_state->data.f; float* output_ptr_batch = output->data.f; // Temporary storage for quantized values and scaling factors. int8_t* quantized_input_ptr = reinterpret_cast(input_quantized->data.uint8); - int8_t* quantized_output_state_ptr = - reinterpret_cast(output_state_quantized->data.uint8); + int8_t* quantized_activation_state_ptr = + reinterpret_cast(activation_state_quantized->data.uint8); int8_t* quantized_cell_state_ptr = reinterpret_cast(cell_state_quantized->data.uint8); float* scaling_factors_ptr = scaling_factors->data.f; @@ -672,14 +719,16 @@ TfLiteStatus EvalHybrid( input_gate_scratch, forget_gate_scratch, cell_scratch, output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr, recovered_cell_weights_ptr, quantized_input_ptr, - quantized_output_state_ptr, quantized_cell_state_ptr, output_state_ptr, - cell_state_ptr, output_ptr_batch); + quantized_activation_state_ptr, quantized_cell_state_ptr, + activation_state_ptr, cell_state_ptr, output_ptr_batch); return kTfLiteOk; } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const auto* params = reinterpret_cast(node->builtin_data); + OpData* op_data = reinterpret_cast(node->user_data); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* input_to_input_weights = @@ -723,8 +772,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Index the scratch buffers pointers to the global scratch buffer. TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); - TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); - TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); + TfLiteTensor* activation_state = + &context->tensors[op_data->activation_state_tensor_index]; + TfLiteTensor* cell_state = + &context->tensors[op_data->cell_state_tensor_index]; + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // TODO(mirkov): add a check that weights are all uint8s or all floats. @@ -738,11 +790,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { cell_to_output_weights, input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, projection_weights, projection_bias, params, - scratch_buffer, output_state, cell_state, output); + scratch_buffer, activation_state, cell_state, output); } case kTfLiteUInt8: { TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); - TfLiteTensor* output_state_quantized = + TfLiteTensor* activation_state_quantized = GetTemporary(context, node, /*index=*/2); TfLiteTensor* cell_state_quantized = GetTemporary(context, node, /*index=*/3); @@ -760,8 +812,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, projection_weights, projection_bias, params, scratch_buffer, scaling_factors, prod_scaling_factors, recovered_cell_weights, - input_quantized, output_state_quantized, cell_state_quantized, - output_state, cell_state, output); + input_quantized, activation_state_quantized, cell_state_quantized, + activation_state, cell_state, output); } default: context->ReportError(context, "Type %d is not currently supported.", @@ -805,13 +857,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, node->inputs->size == kInputNum); TF_LITE_ENSURE(context, node->outputs->size == kOutputNum); - // Only Float32 is supported currently. - // TODO(ycling): Implement quantize uint8 support. - for (int index = 0; index < node->inputs->size; ++index) { - TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]]; - TF_LITE_ENSURE_EQ(context, tensor->type, kTfLiteFloat32); - } - const TfLiteTensor* input = GetInput(context, node, kInputData); const TfLiteTensor* prev_activation = GetInput(context, node, kInputPrevActivation); @@ -821,15 +866,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, input->dims->size, 2); const int num_batches = input->dims->data[0]; + const int input_depth = input->dims->data[1]; TF_LITE_ENSURE_EQ(context, prev_activation->dims->size, 2); TF_LITE_ENSURE_EQ(context, prev_activation->dims->data[0], num_batches); + const int activation_depth = prev_activation->dims->data[1]; + const int total_depth = input_depth + activation_depth; TF_LITE_ENSURE_EQ(context, weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, weights->dims->data[0], 4 * activation_depth); + TF_LITE_ENSURE_EQ(context, weights->dims->data[1], total_depth); + TF_LITE_ENSURE_EQ(context, bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, bias->dims->data[0], 4 * activation_depth); TF_LITE_ENSURE_EQ(context, prev_state->dims->size, 2); TF_LITE_ENSURE_EQ(context, prev_state->dims->data[0], num_batches); + TF_LITE_ENSURE_EQ(context, prev_state->dims->data[1], activation_depth); TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation); TfLiteTensor* state_out = GetOutput(context, node, kOutputState); @@ -843,14 +896,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK( context, context->ResizeTensor(context, state_out, TfLiteIntArrayCopy(prev_state->dims))); + TfLiteIntArray* concat_temp_size = TfLiteIntArrayCreate(2); concat_temp_size->data[0] = num_batches; - concat_temp_size->data[1] = weights->dims->data[1]; + concat_temp_size->data[1] = total_depth; TF_LITE_ENSURE_OK( context, context->ResizeTensor(context, concat_temp, concat_temp_size)); TfLiteIntArray* activation_temp_size = TfLiteIntArrayCreate(2); activation_temp_size->data[0] = num_batches; - activation_temp_size->data[1] = weights->dims->data[0]; + activation_temp_size->data[1] = 4 * activation_depth; TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_temp, activation_temp_size)); @@ -876,18 +930,73 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* activation_temp = GetOutput(context, node, kOutputActivationTemp); - optimized_ops::LstmCell( - // Inputs. - GetTensorData(input), GetTensorDims(input), - GetTensorData(prev_activation), GetTensorDims(prev_activation), - GetTensorData(weights), GetTensorDims(weights), - GetTensorData(bias), GetTensorDims(bias), - GetTensorData(prev_state), GetTensorDims(prev_state), - // Outputs. - GetTensorData(state_out), GetTensorDims(state_out), - GetTensorData(activation_out), GetTensorDims(activation_out), - GetTensorData(concat_temp), GetTensorDims(concat_temp), - GetTensorData(activation_temp), GetTensorDims(activation_temp)); + if (input->type == kTfLiteFloat32 && + prev_activation->type == kTfLiteFloat32 && + weights->type == kTfLiteFloat32 && bias->type == kTfLiteFloat32 && + prev_state->type == kTfLiteFloat32 && state_out->type == kTfLiteFloat32 && + activation_out->type == kTfLiteFloat32 && + concat_temp->type == kTfLiteFloat32 && + activation_temp->type == kTfLiteFloat32) { + optimized_ops::LstmCell( + // Inputs. + GetTensorData(input), GetTensorDims(input), + GetTensorData(prev_activation), GetTensorDims(prev_activation), + GetTensorData(weights), GetTensorDims(weights), + GetTensorData(bias), GetTensorDims(bias), + GetTensorData(prev_state), GetTensorDims(prev_state), + // Outputs. + GetTensorData(state_out), GetTensorDims(state_out), + GetTensorData(activation_out), GetTensorDims(activation_out), + GetTensorData(concat_temp), GetTensorDims(concat_temp), + GetTensorData(activation_temp), GetTensorDims(activation_temp)); + } else if (input->type == kTfLiteUInt8 && + prev_activation->type == kTfLiteUInt8 && + weights->type == kTfLiteUInt8 && bias->type == kTfLiteInt32 && + prev_state->type == kTfLiteInt16 && + state_out->type == kTfLiteInt16 && + activation_out->type == kTfLiteUInt8 && + concat_temp->type == kTfLiteUInt8 && + activation_temp->type == kTfLiteInt16) { + gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context); + int state_scale_log2_rounded; + if (!CheckedLog2(state_out->params.scale, &state_scale_log2_rounded)) { + context->ReportError( + context, + "The internal state of a LSTM cell must have a power-of-two scale."); + return kTfLiteError; + } + const int state_integer_bits = 15 + state_scale_log2_rounded; + if (state_integer_bits != 4) { + context->ReportError(context, + "The only case of quantized LstmCell currently " + "supported is with StateIntegerBits==4"); + return kTfLiteError; + } + + double real_accum_multiplier = 4096 * bias->params.scale; + int32 accum_multiplier; + int accum_shift; + tflite::QuantizeMultiplier(real_accum_multiplier, &accum_multiplier, + &accum_shift); + optimized_ops::LstmCell<4>( + // Inputs. + GetTensorData(input), GetTensorDims(input), + GetTensorData(prev_activation), GetTensorDims(prev_activation), + GetTensorData(weights), GetTensorDims(weights), + GetTensorData(bias), GetTensorDims(bias), + GetTensorData(prev_state), GetTensorDims(prev_state), + // Outputs. + GetTensorData(state_out), GetTensorDims(state_out), + GetTensorData(activation_out), GetTensorDims(activation_out), + GetTensorData(concat_temp), GetTensorDims(concat_temp), + GetTensorData(activation_temp), GetTensorDims(activation_temp), + weights->params.zero_point, accum_multiplier, accum_shift, + gemm_context); + } else { + context->ReportError(context, + "Unsupported combination of data types for LstmCell"); + return kTfLiteError; + } // TODO(ycling): Investigate if this copy can be avoided with the 5-inputs // LSTM kernel. @@ -901,6 +1010,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace basic void* Init(TfLiteContext* context, const char* buffer, size_t length) { + gemm_support::IncrementUsageCounter(context); + const auto* params = reinterpret_cast(buffer); switch (params->kernel_type) { case kTfLiteLSTMFullKernel: @@ -910,6 +1021,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { } } void Free(TfLiteContext* context, void* buffer) { + gemm_support::DecrementUsageCounter(context); + delete reinterpret_cast(buffer); } diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc index 6da29a4a923f16f7b5ad382f51cfd820783504cd..0b7c56133e3cbb3d85f75657b6141620a8019e61 100644 --- a/tensorflow/contrib/lite/kernels/lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/lstm_test.cc @@ -97,6 +97,12 @@ class LSTMOpModel : public SingleOpModel { projection_bias_ = AddNullInput(); } + // Adding the 2 input state tensors. + input_activation_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true); + input_cell_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true); + output_state_ = AddOutput(TensorType_FLOAT32); cell_state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); @@ -227,6 +233,8 @@ class LSTMOpModel : public SingleOpModel { int projection_weights_; int projection_bias_; + int input_activation_state_; + int input_cell_state_; int output_; int output_state_; @@ -352,14 +360,6 @@ class BaseLstmTest : public ::testing::Test { } EXPECT_THAT(lstm->GetOutput(), ElementsAreArray(ArrayFloatNear(expected, tolerance))); - for (int i = 0; i < num_outputs; ++i) { - std::cout << lstm->GetOutput()[i] << ", "; - } - std::cout << std::endl; - for (int i = 0; i < num_outputs; ++i) { - std::cout << expected[i] << ", "; - } - std::cout << std::endl; } } }; diff --git a/tensorflow/contrib/lite/kernels/maximum_minimum_test.cc b/tensorflow/contrib/lite/kernels/maximum_minimum_test.cc index 0752aa1804722accb1f88910fe013ffd632a4503..fd4d5367c5a6369b5ffeeea30a910262bc0796a9 100644 --- a/tensorflow/contrib/lite/kernels/maximum_minimum_test.cc +++ b/tensorflow/contrib/lite/kernels/maximum_minimum_test.cc @@ -126,10 +126,10 @@ TEST(MaximumOpTest, FloatWithBroadcastTest) { TEST(MaximumOpTest, Int32WithBroadcastTest) { std::initializer_list data1 = {1, 0, -1, -2, 3, 11}; std::initializer_list data2 = {2}; - TestModel(BuiltinOperator_MAXIMUM, {TensorType_INT32, {3, 1, 2}}, + TestModel(BuiltinOperator_MAXIMUM, {TensorType_INT32, {3, 1, 2}}, {TensorType_INT32, {1}}, {TensorType_INT32, {3, 1, 2}}, data1, data2, {2, 2, 2, 2, 3, 11}); - TestModel(BuiltinOperator_MINIMUM, {TensorType_INT32, {3, 1, 2}}, + TestModel(BuiltinOperator_MINIMUM, {TensorType_INT32, {3, 1, 2}}, {TensorType_INT32, {1}}, {TensorType_INT32, {3, 1, 2}}, data1, data2, {1, 0, -1, -2, 2, 2}); } diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc index b69a221447db963bcd3a7e6a69f132fe3767bfd1..1f72f3a3c7af4f9e042c9b2ac09252fab5de1a4f 100644 --- a/tensorflow/contrib/lite/kernels/mul.cc +++ b/tensorflow/contrib/lite/kernels/mul.cc @@ -39,6 +39,14 @@ constexpr int kOutputTensor = 0; struct OpData { bool requires_broadcast; + + // Parameters used in the quantized paths where the output is 8bit + int32 output_activation_min; + int32 output_activation_max; + + // Parameters used in all quantized paths + int32_t output_multiplier; + int output_shift; }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -52,6 +60,7 @@ void Free(TfLiteContext* context, void* buffer) { } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); OpData* data = reinterpret_cast(node->user_data); TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); @@ -62,7 +71,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TF_LITE_ENSURE_EQ(context, input1->type, input2->type); - output->type = input2->type; data->requires_broadcast = !HaveSameShapes(input1, input2); @@ -74,6 +82,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output_size = TfLiteIntArrayCopy(input1->dims); } + if (output->type == kTfLiteUInt8) { + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + + if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) { + double real_multiplier = + input1->params.scale * input2->params.scale / output->params.scale; + QuantizeMultiplierSmallerThanOneExp( + real_multiplier, &data->output_multiplier, &data->output_shift); + data->output_shift *= -1; + } + return context->ResizeTensor(context, output, output_size); } @@ -83,8 +105,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input1, const TfLiteTensor* input2, TfLiteTensor* output) { float output_activation_min, output_activation_max; - CalculateActivationRangeFloat(params->activation, &output_activation_min, - &output_activation_max); + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); #define TF_LITE_MUL(type, opname) \ type::opname(GetTensorData(input1), GetTensorDims(input1), \ GetTensorData(input2), GetTensorDims(input2), \ @@ -107,42 +129,60 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, } template -void EvalQuantized(TfLiteContext* context, TfLiteNode* node, - TfLiteMulParams* params, const OpData* data, - const TfLiteTensor* input1, const TfLiteTensor* input2, - TfLiteTensor* output) { - auto input1_offset = -input1->params.zero_point; - auto input2_offset = -input2->params.zero_point; - auto output_offset = output->params.zero_point; - - int32_t output_multiplier; - int output_shift; - - double real_multiplier = - input1->params.scale * input2->params.scale / output->params.scale; - QuantizeMultiplierSmallerThanOneExp(real_multiplier, &output_multiplier, - &output_shift); - output_shift *= -1; - - int32 output_activation_min, output_activation_max; - CalculateActivationRangeUint8(params->activation, output, - &output_activation_min, &output_activation_max); - -#define TF_LITE_MUL(type, opname) \ - type::opname(GetTensorData(input1), GetTensorDims(input1), \ - input1_offset, GetTensorData(input2), \ - GetTensorDims(input2), input2_offset, output_offset, \ - output_multiplier, output_shift, output_activation_min, \ - output_activation_max, GetTensorData(output), \ +TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteMulParams* params, const OpData* data, + const TfLiteTensor* input1, + const TfLiteTensor* input2, TfLiteTensor* output) { + if (input1->type == kTfLiteUInt8 && input2->type == kTfLiteUInt8 && + output->type == kTfLiteUInt8) { +#define TF_LITE_MUL(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + -input1->params.zero_point, GetTensorData(input2), \ + GetTensorDims(input2), -input2->params.zero_point, \ + output->params.zero_point, data->output_multiplier, \ + data->output_shift, data->output_activation_min, \ + data->output_activation_max, GetTensorData(output), \ GetTensorDims(output)); - // The quantized version of Mul doesn't support activations, so we - // always use BroadcastMul. - if (kernel_type == kReference) { - TF_LITE_MUL(reference_ops, BroadcastMul); + // The quantized version of Mul doesn't support activations, so we + // always use BroadcastMul. + if (kernel_type == kReference) { + TF_LITE_MUL(reference_ops, BroadcastMul); + } else { + TF_LITE_MUL(optimized_ops, BroadcastMul); + } +#undef TF_LITE_MUL + } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 && + output->type == kTfLiteInt16) { +#define TF_LITE_MUL(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + GetTensorData(output), GetTensorDims(output)); + if (kernel_type == kReference) { + TF_LITE_MUL(reference_ops, Mul); + } else { + TF_LITE_MUL(optimized_ops, Mul); + } +#undef TF_LITE_MUL + } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 && + output->type == kTfLiteUInt8) { +#define TF_LITE_MUL(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output->params.zero_point, data->output_activation_min, \ + data->output_activation_max, GetTensorData(output), \ + GetTensorDims(output)); + if (kernel_type == kReference) { + TF_LITE_MUL(reference_ops, Mul); + } else { + TF_LITE_MUL(optimized_ops, Mul); + } +#undef TF_LITE_MUL } else { - TF_LITE_MUL(optimized_ops, BroadcastMul); + context->ReportError( + context, "Unsupported combination of input and output types in Mul."); + return kTfLiteError; } -#undef TF_LITE_MUL + return kTfLiteOk; } template @@ -156,12 +196,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (output->type == kTfLiteFloat32) { EvalFloat(context, node, params, data, input1, input2, output); - } else if (output->type == kTfLiteUInt8) { - EvalQuantized(context, node, params, data, input1, input2, - output); + } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) { + TF_LITE_ENSURE_OK( + context, EvalQuantized(context, node, params, data, input1, + input2, output)); } else { context->ReportError( - context, "Mul only supports FLOAT32 and quantized UINT8 now, got %d.", + context, + "Mul only supports FLOAT32 and quantized UINT8 and INT16 now, got %d.", output->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/mul_test.cc b/tensorflow/contrib/lite/kernels/mul_test.cc index f1a30f82634631ba8320421d5b36ffe446f443fa..43d56e50d2686ff2624f36a0c5d8e43279a572cc 100644 --- a/tensorflow/contrib/lite/kernels/mul_test.cc +++ b/tensorflow/contrib/lite/kernels/mul_test.cc @@ -58,6 +58,9 @@ class FloatMulOpModel : public BaseMulOpModel { const float kQuantizedStep = 2.0 / 255.0; const float kQuantizedTolerance = 2.0 * kQuantizedStep + kQuantizedStep * kQuantizedStep; +const float kQuantizedStepInt16 = 2.0 / 32767.0; +const float kQuantizedToleranceInt16 = + 2.0 * kQuantizedStepInt16 + kQuantizedStepInt16 * kQuantizedStepInt16; class QuantizedMulOpModel : public BaseMulOpModel { public: @@ -67,6 +70,11 @@ class QuantizedMulOpModel : public BaseMulOpModel { return Dequantize(ExtractVector(output_), GetScale(output_), GetZeroPoint(output_)); } + + std::vector GetDequantizedOutputInt16() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } }; TEST(FloatMulOpTest, NoActivation) { @@ -138,6 +146,38 @@ TEST(QuantizedMulOpTest, NoActivation) { kQuantizedTolerance))); } +TEST(QuantizedMulOpTest, NoActivationInt16) { + const float kMin = -1.f; + const float kMax = 32767.f / 32768.f; + QuantizedMulOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_INT16, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_INT16, {}, kMin, kMax}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), {-0.8, 0.2, 0.9, 0.7}); + m.QuantizeAndPopulate(m.input2(), {0.6, 0.4, 0.9, 0.8}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutputInt16(), + ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56}, + kQuantizedToleranceInt16))); +} + +TEST(QuantizedMulOpTest, NoActivationInt16WithUint8Output) { + const float kMinInt16 = -1.f; + const float kMaxInt16 = 32767.f / 32768.f; + const float kMinUint8 = -1.f; + const float kMaxUint8 = 127.f / 128.f; + QuantizedMulOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMinInt16, kMaxInt16}, + {TensorType_INT16, {1, 2, 2, 1}, kMinInt16, kMaxInt16}, + {TensorType_UINT8, {}, kMinUint8, kMaxUint8}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), {-0.8, 0.2, 0.9, 0.7}); + m.QuantizeAndPopulate(m.input2(), {0.6, 0.4, 0.9, 0.8}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56}, + kQuantizedTolerance))); +} + // for quantized Mul, the error shouldn't exceed 2*step float GetTolerance(int min, int max) { float kQuantizedStep = (max - min) / 255.0; diff --git a/tensorflow/contrib/lite/kernels/neg_test.cc b/tensorflow/contrib/lite/kernels/neg_test.cc index 3c95ac8cc2727fdeff5f39aa2fe30eb6129a6022..3d3594c60bbe1684dff7b1816f5f8a715b1abc60 100644 --- a/tensorflow/contrib/lite/kernels/neg_test.cc +++ b/tensorflow/contrib/lite/kernels/neg_test.cc @@ -58,9 +58,9 @@ TEST(NegOpModel, NegFloat) { TEST(NegOpModel, NegInt32) { NegOpModel m({TensorType_INT32, {2, 3}}, {TensorType_INT32, {2, 3}}); - m.SetInput({-2, -1, 0, 1, 2, 3}); + m.SetInput({-2, -1, 0, 1, 2, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 1, 0, -1, -2, -3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 1, 0, -1, -2, -3})); } TEST(NegOpModel, NegInt64) { diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc index bcad58406af1cdd466e410a06011641692194be4..1c728a473326564a85a5e7d3d72718265979e29a 100644 --- a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc +++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc @@ -95,6 +95,12 @@ class LSTMOpModel : public SingleOpModel { projection_bias_ = AddNullInput(); } + // Adding the 2 input state tensors. + input_activation_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true); + input_cell_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true); + output_state_ = AddOutput(TensorType_FLOAT32); cell_state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); @@ -228,6 +234,8 @@ class LSTMOpModel : public SingleOpModel { int projection_weights_; int projection_bias_; + int input_activation_state_; + int input_cell_state_; int output_; int output_state_; diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc index 83668cb4ca87e9eb53ab4ba9e88f91e3315594de..4be8c243c17c533e8c7d5aa7bb50c9d790b06995 100644 --- a/tensorflow/contrib/lite/kernels/pad.cc +++ b/tensorflow/contrib/lite/kernels/pad.cc @@ -128,7 +128,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // TODO(nupurgarg): Change kernel implementation to use padding arrays in // forward order (depth, width, height, batch). // Build paddings in order of int[] = {batch, height, width, depth} to match - // kernel implementation of Pad in referenced_ops.h and optimized_ops.h. + // kernel implementation of Pad in reference_ops.h and optimized_ops.h. for (int idx = op_context.dims - 1; idx >= 0; --idx) { before_padding.push_back(paddings_data[idx * 2]); after_padding.push_back(paddings_data[idx * 2 + 1]); diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc index 311e9b8399726d758182e1f084a890d6f10e57ce..7240fe04ccdadfb7b9703c3f2775c4b3502bd1d9 100644 --- a/tensorflow/contrib/lite/kernels/pooling.cc +++ b/tensorflow/contrib/lite/kernels/pooling.cc @@ -80,24 +80,24 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { // Matching GetWindowedOutputSize in TensorFlow. auto padding = params->padding; - auto computeOutSize = [padding](int imageSize, int filterSize, - int stride) -> int { + auto compute_out_size = [padding](int image_size, int filter_size, + int stride) -> int { return padding == kTfLitePaddingSame - ? (imageSize + stride - 1) / stride + ? (image_size + stride - 1) / stride : padding == kTfLitePaddingValid - ? (imageSize - filterSize + stride) / stride + ? (image_size - filter_size + stride) / stride : 0; }; - int outWidth = - computeOutSize(width, params->filter_width, params->stride_width); - int outHeight = - computeOutSize(height, params->filter_height, params->stride_height); + int out_width = + compute_out_size(width, params->filter_width, params->stride_width); + int out_height = + compute_out_size(height, params->filter_height, params->stride_height); data->padding.height = ComputePadding(params->stride_height, 1, height, - params->filter_height, outHeight); + params->filter_height, out_height); data->padding.width = ComputePadding(params->stride_width, 1, width, - params->filter_width, outWidth); + params->filter_width, out_width); if (input->type == kTfLiteUInt8) { if (pool_type == kAverage || pool_type == kMax) { @@ -111,12 +111,12 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { } } - TfLiteIntArray* outputSize = TfLiteIntArrayCreate(4); - outputSize->data[0] = batches; - outputSize->data[1] = outHeight; - outputSize->data[2] = outWidth; - outputSize->data[3] = channels_out; - return context->ResizeTensor(context, output, outputSize); + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = batches; + output_size->data[1] = out_height; + output_size->data[2] = out_width; + output_size->data[3] = channels_out; + return context->ResizeTensor(context, output, output_size); } template @@ -124,14 +124,15 @@ void AverageEvalFloat(TfLiteContext* context, TfLiteNode* node, TfLitePoolParams* params, OpData* data, const TfLiteTensor* input, TfLiteTensor* output) { float activation_min, activation_max; - CalculateActivationRangeFloat(params->activation, &activation_min, - &activation_max); -#define TF_LITE_AVERAGE_POOL(type) \ - type::AveragePool( \ - GetTensorData(input), GetTensorDims(input), params->stride_width, \ - params->stride_height, data->padding.width, data->padding.height, \ - params->filter_width, params->filter_height, activation_min, \ - activation_max, GetTensorData(output), GetTensorDims(output)) + CalculateActivationRange(params->activation, &activation_min, + &activation_max); +#define TF_LITE_AVERAGE_POOL(type) \ + type::AveragePool(GetTensorData(input), GetTensorShape(input), \ + params->stride_width, params->stride_height, \ + data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, \ + activation_min, activation_max, \ + GetTensorData(output), GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_AVERAGE_POOL(reference_ops); } else { @@ -148,13 +149,13 @@ void AverageEvalQuantized(TfLiteContext* context, TfLiteNode* node, int32_t activation_max; CalculateActivationRangeUint8(params->activation, output, &activation_min, &activation_max); -#define TF_LITE_AVERAGE_POOL(type) \ - type::AveragePool(GetTensorData(input), GetTensorDims(input), \ - params->stride_width, params->stride_height, \ - data->padding.width, data->padding.height, \ - params->filter_width, params->filter_height, \ - activation_min, activation_max, \ - GetTensorData(output), GetTensorDims(output)) +#define TF_LITE_AVERAGE_POOL(type) \ + type::AveragePool(GetTensorData(input), GetTensorShape(input), \ + params->stride_width, params->stride_height, \ + data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, \ + activation_min, activation_max, \ + GetTensorData(output), GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_AVERAGE_POOL(reference_ops); } else { @@ -168,14 +169,15 @@ void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node, TfLitePoolParams* params, OpData* data, const TfLiteTensor* input, TfLiteTensor* output) { float activation_min, activation_max; - CalculateActivationRangeFloat(params->activation, &activation_min, - &activation_max); -#define TF_LITE_MAX_POOL(type) \ - type::MaxPool( \ - GetTensorData(input), GetTensorDims(input), params->stride_width, \ - params->stride_height, data->padding.width, data->padding.height, \ - params->filter_width, params->filter_height, activation_min, \ - activation_max, GetTensorData(output), GetTensorDims(output)) + CalculateActivationRange(params->activation, &activation_min, + &activation_max); +#define TF_LITE_MAX_POOL(type) \ + type::MaxPool(GetTensorData(input), GetTensorShape(input), \ + params->stride_width, params->stride_height, \ + data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, activation_min, \ + activation_max, GetTensorData(output), \ + GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_MAX_POOL(reference_ops); } else { @@ -193,12 +195,12 @@ void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node, CalculateActivationRangeUint8(params->activation, output, &activation_min, &activation_max); #define TF_LITE_MAX_POOL(type) \ - type::MaxPool(GetTensorData(input), GetTensorDims(input), \ + type::MaxPool(GetTensorData(input), GetTensorShape(input), \ params->stride_width, params->stride_height, \ data->padding.width, data->padding.height, \ params->filter_width, params->filter_height, activation_min, \ activation_max, GetTensorData(output), \ - GetTensorDims(output)) + GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_MAX_POOL(reference_ops); } else { @@ -212,14 +214,15 @@ void L2EvalFloat(TfLiteContext* context, TfLiteNode* node, TfLitePoolParams* params, OpData* data, const TfLiteTensor* input, TfLiteTensor* output) { float activation_min, activation_max; - CalculateActivationRangeFloat(params->activation, &activation_min, - &activation_max); -#define TF_LITE_L2_POOL(type) \ - type::L2Pool( \ - GetTensorData(input), GetTensorDims(input), params->stride_width, \ - params->stride_height, data->padding.width, data->padding.height, \ - params->filter_width, params->filter_height, activation_min, \ - activation_max, GetTensorData(output), GetTensorDims(output)) + CalculateActivationRange(params->activation, &activation_min, + &activation_max); +#define TF_LITE_L2_POOL(type) \ + type::L2Pool(GetTensorData(input), GetTensorShape(input), \ + params->stride_width, params->stride_height, \ + data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, activation_min, \ + activation_max, GetTensorData(output), \ + GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_L2_POOL(reference_ops); } else { diff --git a/tensorflow/contrib/lite/kernels/pow.cc b/tensorflow/contrib/lite/kernels/pow.cc new file mode 100644 index 0000000000000000000000000000000000000000..4a539c47a8fbe392e0e6542ab8ffb9065b550485 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pow.cc @@ -0,0 +1,143 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace pow { +namespace { + +// Input/output tensor index. +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +// Op data for pow op. +struct OpData { + bool requires_broadcast; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + data->requires_broadcast = false; + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + OpData* data = reinterpret_cast(node->user_data); + + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, input1->type, input2->type); + + const TfLiteType type = input1->type; + if (type != kTfLiteInt32 && type != kTfLiteFloat32) { + context->ReportError(context, "Unsupported data type %d.", type); + return kTfLiteError; + } + output->type = type; + + data->requires_broadcast = !HaveSameShapes(input1, input2); + + TfLiteIntArray* output_size = nullptr; + if (data->requires_broadcast) { + TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast( + context, input1, input2, &output_size)); + } else { + output_size = TfLiteIntArrayCopy(input1->dims); + } + + return context->ResizeTensor(context, output, output_size); +} + +template +void PowImpl(const TfLiteTensor* input1, const TfLiteTensor* input2, + TfLiteTensor* output, bool requires_broadcast) { + if (requires_broadcast) { + reference_ops::BroadcastPow(GetTensorData(input1), GetTensorDims(input1), + GetTensorData(input2), GetTensorDims(input2), + GetTensorData(output), + GetTensorDims(output)); + } else { + reference_ops::Pow(GetTensorData(input1), GetTensorDims(input1), + GetTensorData(input2), GetTensorDims(input2), + GetTensorData(output), GetTensorDims(output)); + } +} + +TfLiteStatus CheckValue(TfLiteContext* context, const TfLiteTensor* input) { + const int64_t num_elements = NumElements(input); + const int32_t* data = GetTensorData(input); + for (int i = 0; i < num_elements; ++i) { + if (data[i] < 0) { + context->ReportError(context, + "POW does not support negative value for int32."); + return kTfLiteError; + } + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); + + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + switch (output->type) { + case kTfLiteInt32: { + // TensorFlow does not support negative for int32. + TF_LITE_ENSURE_OK(context, CheckValue(context, input2)); + PowImpl(input1, input2, output, data->requires_broadcast); + break; + } + case kTfLiteFloat32: { + PowImpl(input1, input2, output, data->requires_broadcast); + break; + } + default: { + context->ReportError(context, "Unsupported data type: %d", output->type); + return kTfLiteError; + } + } + return kTfLiteOk; +} + +} // namespace +} // namespace pow + +TfLiteRegistration* Register_POW() { + static TfLiteRegistration r = {pow::Init, pow::Free, pow::Prepare, pow::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/pow_test.cc b/tensorflow/contrib/lite/kernels/pow_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..474d323bc3a1a0f224aa0575a5bbd35394aa2f53 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pow_test.cc @@ -0,0 +1,117 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +template +class PowOpModel : public SingleOpModel { + public: + PowOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_POW, BuiltinOptions_PowOptions, + CreatePowOptions(builder_).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input1_; + int input2_; + int output_; +}; + +TEST(PowOpModel, Simple) { + PowOpModel model({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {}}); + model.PopulateTensor(model.input1(), {12, 2, 7, 8}); + model.PopulateTensor(model.input2(), {1, 2, 3, 1}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), ElementsAre(12, 4, 343, 8)); +} + +TEST(PowOpModel, NegativeAndZeroValue) { + PowOpModel model({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {}}); + model.PopulateTensor(model.input1(), {0, 2, -7, 8}); + model.PopulateTensor(model.input2(), {1, 2, 3, 0}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), ElementsAre(0, 4, -343, 1)); +} + +TEST(PowOpModel, Float) { + PowOpModel model({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}); + model.PopulateTensor(model.input1(), {0.3, 0.4, 0.7, 5.8}); + model.PopulateTensor(model.input2(), {0.5, 2.7, 3.1, 3.2}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0.5477226, 0.08424846, 0.33098164, 277.313}, 1e-3))); +} + +TEST(PowOpModel, NegativeFloatTest) { + PowOpModel model({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}); + model.PopulateTensor(model.input1(), {0.3, 0.4, 0.7, 5.8}); + model.PopulateTensor(model.input2(), {0.5, -2.7, 3.1, -3.2}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0.5477226, 11.869653, 0.33098164, 0.003606}, 1e-3))); +} + +TEST(PowOpModel, BroadcastTest) { + PowOpModel model({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1}}, {TensorType_INT32, {}}); + model.PopulateTensor(model.input1(), {12, 2, 7, 8}); + model.PopulateTensor(model.input2(), {4}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), ElementsAre(20736, 16, 2401, 4096)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/mean.cc b/tensorflow/contrib/lite/kernels/reduce.cc similarity index 72% rename from tensorflow/contrib/lite/kernels/mean.cc rename to tensorflow/contrib/lite/kernels/reduce.cc index 03e5db24de3f3c2d4e17df21bc0b592a02078d6b..31c331a8c61ded203af9ff2ae127cb6f985e2932 100644 --- a/tensorflow/contrib/lite/kernels/mean.cc +++ b/tensorflow/contrib/lite/kernels/reduce.cc @@ -25,21 +25,21 @@ limitations under the License. namespace tflite { namespace ops { namespace builtin { -namespace mean { +namespace reduce { -// This file has reference implementation of Mean. +// This file has reference implementation of reduce_* operators. enum KernelType { kReference, }; -struct MeanContext { - MeanContext(TfLiteContext* context, TfLiteNode* node) { - params = reinterpret_cast(node->builtin_data); +struct OpContext { + OpContext(TfLiteContext* context, TfLiteNode* node) { + params = reinterpret_cast(node->builtin_data); input = GetInput(context, node, 0); axis = GetInput(context, node, 1); output = GetOutput(context, node, 0); } - TfLiteMeanParams* params; + TfLiteReducerParams* params; const TfLiteTensor* input; const TfLiteTensor* axis; TfLiteTensor* output; @@ -58,7 +58,7 @@ void Free(TfLiteContext* context, void* buffer) { } // Resizes the temp tensor that stores resolved axis. -TfLiteStatus ResizeTempAxis(TfLiteContext* context, MeanContext* op_context, +TfLiteStatus ResizeTempAxis(TfLiteContext* context, OpContext* op_context, TfLiteTensor* resolved_axis) { TfLiteIntArray* axis_size = TfLiteIntArrayCreate(1); axis_size->data[0] = static_cast(NumElements(op_context->axis)); @@ -66,7 +66,7 @@ TfLiteStatus ResizeTempAxis(TfLiteContext* context, MeanContext* op_context, } // Resizes the temp tensor that stores temp sum of reduced elements. -TfLiteStatus ResizeTempSum(TfLiteContext* context, MeanContext* op_context, +TfLiteStatus ResizeTempSum(TfLiteContext* context, OpContext* op_context, TfLiteTensor* temp_sum) { TfLiteIntArray* size = TfLiteIntArrayCreate(1); size->data[0] = static_cast(NumElements(op_context->output)); @@ -74,8 +74,7 @@ TfLiteStatus ResizeTempSum(TfLiteContext* context, MeanContext* op_context, } // Resizes output array based on the input size and resolved axis. -TfLiteStatus ResizeOutputTensor(TfLiteContext* context, - MeanContext* op_context) { +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, OpContext* op_context) { size_t num_axis = NumElements(op_context->axis); const TfLiteIntArray* input_dims = op_context->input->dims; int input_num_dims = NumDimensions(op_context->input); @@ -140,7 +139,7 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, // Initializes temp tensors to store index and resolved axis. TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node, - MeanContext* op_context) { + OpContext* op_context) { // Creates a temp index to iterate through input data. int* scratch_tensor_index = reinterpret_cast(node->user_data); TfLiteIntArrayFree(node->temporaries); @@ -180,33 +179,44 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node, return kTfLiteOk; } -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { +TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - MeanContext op_context(context, node); + OpContext op_context(context, node); TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context)); TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); - TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2); // Leaves work to Eval if axis is not constant; else resizes output. if (!IsConstantTensor(op_context.axis)) { SetTensorToDynamic(op_context.output); SetTensorToDynamic(resolved_axis); - SetTensorToDynamic(temp_sum); return kTfLiteOk; } resolved_axis->allocation_type = kTfLiteArenaRw; TF_LITE_ENSURE_OK(context, ResizeTempAxis(context, &op_context, resolved_axis)); TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + return kTfLiteOk; +} + +TfLiteStatus PrepareMean(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_OK(context, PrepareSimple(context, node)); + + // reduce_mean requires a buffer to store intermediate sum result. + OpContext op_context(context, node); + TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2); + if (!IsConstantTensor(op_context.axis)) { + SetTensorToDynamic(temp_sum); + return kTfLiteOk; + } temp_sum->allocation_type = kTfLiteArenaRw; return ResizeTempSum(context, &op_context, temp_sum); } template -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - MeanContext op_context(context, node); +TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) { + OpContext op_context(context, node); int num_axis = static_cast(NumElements(op_context.axis)); TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0); TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); @@ -255,16 +265,75 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { #undef TF_LITE_MEAN return kTfLiteOk; } -} // namespace mean + +template +TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) { + OpContext op_context(context, node); + int num_axis = static_cast(NumElements(op_context.axis)); + TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0); + TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); + // Resize the output tensor if the output tensor is dynamic. + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, + ResizeTempAxis(context, &op_context, resolved_axis)); + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + } + +#define TF_LITE_SUM(kernel_type, data_type) \ + kernel_type::Sum<>( \ + GetTensorData(op_context.input), \ + op_context.input->dims->data, op_context.input->dims->size, \ + GetTensorData(op_context.output), \ + op_context.output->dims->data, op_context.output->dims->size, \ + GetTensorData(op_context.axis), num_axis, \ + op_context.params->keep_dims, GetTensorData(temp_index), \ + GetTensorData(resolved_axis)) + + if (kernel_type == kReference) { + switch (op_context.input->type) { + case kTfLiteFloat32: + TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, float)); + break; + case kTfLiteInt32: + TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, int)); + break; + case kTfLiteInt64: + TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, int64_t)); + break; + case kTfLiteUInt8: + TF_LITE_ENSURE_EQ(context, op_context.input->params.scale, + op_context.output->params.scale); + TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point, + op_context.output->params.zero_point); + TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, uint8_t)); + break; + default: + return kTfLiteError; + } + } +#undef TF_LITE_SUM + return kTfLiteOk; +} + +} // namespace reduce TfLiteRegistration* Register_MEAN_REF() { - static TfLiteRegistration r = {mean::Init, mean::Free, mean::Prepare, - mean::Eval}; + static TfLiteRegistration r = {reduce::Init, reduce::Free, + reduce::PrepareMean, + reduce::EvalMean}; + return &r; +} + +TfLiteRegistration* Register_SUM_REF() { + static TfLiteRegistration r = {reduce::Init, reduce::Free, + reduce::PrepareSimple, + reduce::EvalSum}; return &r; } // TODO(kanlig): add optimized implementation of Mean. TfLiteRegistration* Register_MEAN() { return Register_MEAN_REF(); } +TfLiteRegistration* Register_SUM() { return Register_SUM_REF(); } } // namespace builtin } // namespace ops diff --git a/tensorflow/contrib/lite/kernels/mean_test.cc b/tensorflow/contrib/lite/kernels/reduce_test.cc similarity index 53% rename from tensorflow/contrib/lite/kernels/mean_test.cc rename to tensorflow/contrib/lite/kernels/reduce_test.cc index 79c9957f76fdb994be0a71f2e90b883435de4815..9e946822c686f6f20505d60b6161239624c94696 100644 --- a/tensorflow/contrib/lite/kernels/mean_test.cc +++ b/tensorflow/contrib/lite/kernels/reduce_test.cc @@ -23,7 +23,7 @@ namespace { using ::testing::ElementsAreArray; -class BaseMeanOpModel : public SingleOpModel { +class BaseOpModel : public SingleOpModel { public: void SetAxis(std::initializer_list data) { PopulateTensor(axis_, data); } @@ -53,7 +53,7 @@ class BaseMeanOpModel : public SingleOpModel { }; // Model for the tests case where axis is a const tensor. -class MeanOpConstModel : public BaseMeanOpModel { +class MeanOpConstModel : public BaseOpModel { public: MeanOpConstModel(const TensorData& input, const TensorData& output, std::initializer_list axis_shape, @@ -61,26 +61,59 @@ class MeanOpConstModel : public BaseMeanOpModel { input_ = AddInput(input); axis_ = AddConstInput(TensorType_INT32, axis, axis_shape); output_ = AddOutput(output); - SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_MeanOptions, - CreateMeanOptions(builder_, keep_dims).Union()); + SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); BuildInterpreter({GetShape(input_)}); } }; // Model for the tests case where axis is a dynamic tensor. -class MeanOpDynamicModel : public BaseMeanOpModel { +class MeanOpDynamicModel : public BaseOpModel { public: MeanOpDynamicModel(const TensorData& input, const TensorData& output, const TensorData& axis, bool keep_dims) { input_ = AddInput(input); axis_ = AddInput(axis); output_ = AddOutput(output); - SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_MeanOptions, - CreateMeanOptions(builder_, keep_dims).Union()); + SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); BuildInterpreter({GetShape(input_)}); } }; +// Model for the tests case where axis is a const tensor. +class SumOpConstModel : public BaseOpModel { + public: + SumOpConstModel(const TensorData& input, const TensorData& output, + std::initializer_list axis_shape, + std::initializer_list axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddConstInput(TensorType_INT32, axis, axis_shape); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_SUM, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +// Model for the tests case where axis is a dynamic tensor. +class SumOpDynamicModel : public BaseOpModel { + public: + SumOpDynamicModel(const TensorData& input, const TensorData& output, + const TensorData& axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddInput(axis); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_SUM, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +// for quantized Add, the error shouldn't exceed step +float GetTolerance(int min, int max) { return (max - min) / 255.0; } + +// Tests for reduce_mean TEST(ConstFloatMeanOpTest, NotKeepDims) { std::initializer_list data = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, @@ -149,8 +182,6 @@ TEST(DynamicFloatMeanOpTest, Scale) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({9.527}))); } -// for quantized Add, the error shouldn't exceed step -float GetTolerance(int min, int max) { return (max - min) / 255.0; } TEST(ConstUint8MeanOpTest, NotKeepDims) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); @@ -209,6 +240,135 @@ TEST(DynamicUint8MeanOpTest, KeepDims) { ElementsAreArray(ArrayFloatNear({9.2815, 0.3695}, kQuantizedTolerance))); } +// Tests for reduce_sum + +TEST(ConstFloatSumOpTest, NotKeepDims) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + SumOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, + {4}, {1, 0, -3, -3}, false); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({144, 156}))); +} + +TEST(ConstFloatSumOpTest, KeepDims) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + SumOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, + {2}, {0, 2}, true); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({84, 100, 116}))); +} + +TEST(DynamicFloatSumOpTest, NotKeepDims) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + SumOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}}, + false); + std::initializer_list axis = {1, 0, -3, -3}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({144, 156}))); +} + +TEST(DynamicFloatSumOpTest, KeepDims) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + SumOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}}, true); + std::initializer_list axis = {0, 2}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({84, 100, 116}))); +} + +TEST(DynamicFloatSumOpTest, Scale) { + std::initializer_list data = {9.527}; + SumOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}}, + {TensorType_INT32, {1}}, true); + std::initializer_list axis = {0}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({9.527}))); +} + +TEST(ConstUint8SumOpTest, NotKeepDims) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::initializer_list data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + SumOpConstModel m({TensorType_UINT8, {1, 3, 2}, -1.0, 1.0}, + {TensorType_UINT8, {2}, -1.0, 1.0}, {1}, {1}, false); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + ArrayFloatNear({-0.823529, -0.815686}, kQuantizedTolerance))); +} + +TEST(ConstUint8SumOpTest, KeepDims) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::initializer_list data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + SumOpConstModel m({TensorType_UINT8, {3, 2}, -1.0, 1.0}, + {TensorType_UINT8, {3}, -1.0, 1.0}, {1}, {1}, true); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({-0.407843, -0.313726, 0.0941177}, + kQuantizedTolerance))); +} + +TEST(DynamicUint8SumOpTest, NotKeepDims) { + float kQuantizedTolerance = GetTolerance(-5.0, 2.0); + std::initializer_list data = {1.3, -4.8, -3.6, 0.24}; + SumOpDynamicModel m({TensorType_UINT8, {2, 2}, -5.0, 2.0}, + {TensorType_UINT8, {2}, -5.0, 2.0}, + {TensorType_INT32, {1}}, false); + std::initializer_list axis = {1}; + m.SetAxis(axis); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + ArrayFloatNear({1.48235, 1.64706}, kQuantizedTolerance))); +} + +TEST(DynamicUint8SumOpTest, KeepDims) { + float kQuantizedTolerance = GetTolerance(-10.0, 12.0); + std::initializer_list data = {11.14, -0.14, 7.423, 0.879}; + SumOpDynamicModel m({TensorType_UINT8, {2, 2}, -10.0, 12.0}, + {TensorType_UINT8, {2}, -10.0, 12.0}, + {TensorType_INT32, {1}}, true); + std::initializer_list axis = {0}; + m.SetAxis(axis); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({6.47059, 10.698}, kQuantizedTolerance))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 7bb28d4de7402a45954691a2e031e3b6b7433ffb..0ca08cd8f38216549b4383ebaacbf4c54442cd97 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -22,6 +22,7 @@ namespace custom { TfLiteRegistration* Register_AUDIO_SPECTROGRAM(); TfLiteRegistration* Register_MFCC(); +TfLiteRegistration* Register_DETECTION_POSTPROCESS(); } // namespace custom @@ -88,6 +89,7 @@ TfLiteRegistration* Register_LESS_EQUAL(); TfLiteRegistration* Register_FLOOR(); TfLiteRegistration* Register_TILE(); TfLiteRegistration* Register_NEG(); +TfLiteRegistration* Register_SUM(); TfLiteRegistration* Register_SELECT(); TfLiteRegistration* Register_SLICE(); TfLiteRegistration* Register_SIN(); @@ -96,6 +98,10 @@ TfLiteRegistration* Register_EXPAND_DIMS(); TfLiteRegistration* Register_SPARSE_TO_DENSE(); TfLiteRegistration* Register_EQUAL(); TfLiteRegistration* Register_NOT_EQUAL(); +TfLiteRegistration* Register_SQRT(); +TfLiteRegistration* Register_RSQRT(); +TfLiteRegistration* Register_SHAPE(); +TfLiteRegistration* Register_POW(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -117,7 +123,9 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP()); AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, Register_EMBEDDING_LOOKUP_SPARSE()); - AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED()); + AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED(), + /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION()); AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP()); AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX()); @@ -170,16 +178,23 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_SIN, Register_SIN()); AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV()); AddBuiltin(BuiltinOperator_TILE, Register_TILE()); + AddBuiltin(BuiltinOperator_SUM, Register_SUM()); AddBuiltin(BuiltinOperator_EXPAND_DIMS, Register_EXPAND_DIMS()); AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE()); AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL()); AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL()); + AddBuiltin(BuiltinOperator_SQRT, Register_SQRT()); + AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT()); + AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE()); + AddBuiltin(BuiltinOperator_POW, Register_POW()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. AddCustom("Mfcc", tflite::ops::custom::Register_MFCC()); AddCustom("AudioSpectrogram", tflite::ops::custom::Register_AUDIO_SPECTROGRAM()); + AddCustom("TFLite_Detection_PostProcess", + tflite::ops::custom::Register_DETECTION_POSTPROCESS()); } } // namespace builtin diff --git a/tensorflow/contrib/lite/kernels/select_test.cc b/tensorflow/contrib/lite/kernels/select_test.cc index cfe24a5fc92765747d1c75bc3e6964b959e2205d..4664b9acb444747167f991944ddc120e9941ccd6 100644 --- a/tensorflow/contrib/lite/kernels/select_test.cc +++ b/tensorflow/contrib/lite/kernels/select_test.cc @@ -88,11 +88,11 @@ TEST(SelectOpTest, SelectUInt8) { TensorType_UINT8); model.PopulateTensor(model.input1(), {false, true, false, false}); - model.PopulateTensor(model.input2(), {1, 2, 3, 4}); - model.PopulateTensor(model.input3(), {5, 6, 7, 8}); + model.PopulateTensor(model.input2(), {1, 2, 3, 4}); + model.PopulateTensor(model.input3(), {5, 6, 7, 8}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 2, 7, 8})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 2, 7, 8})); EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); } @@ -101,11 +101,11 @@ TEST(SelectOpTest, SelectInt32) { TensorType_INT32); model.PopulateTensor(model.input1(), {false, true, false, false}); - model.PopulateTensor(model.input2(), {1, 2, 3, 4}); - model.PopulateTensor(model.input3(), {5, 6, 7, 8}); + model.PopulateTensor(model.input2(), {1, 2, 3, 4}); + model.PopulateTensor(model.input3(), {5, 6, 7, 8}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 2, 7, 8})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 2, 7, 8})); EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); } @@ -113,11 +113,11 @@ TEST(SelectOpTest, RankOneSelectInt32) { SelectOpModel model({2}, {2, 1, 2, 1}, {2, 1, 2, 1}, TensorType_INT32); model.PopulateTensor(model.input1(), {false, true}); - model.PopulateTensor(model.input2(), {1, 2, 3, 4}); - model.PopulateTensor(model.input3(), {5, 6, 7, 8}); + model.PopulateTensor(model.input2(), {1, 2, 3, 4}); + model.PopulateTensor(model.input3(), {5, 6, 7, 8}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 6, 3, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 6, 3, 4})); EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 2, 1})); } @@ -125,11 +125,11 @@ TEST(SelectOpTest, RankZeroSelectInt32) { SelectOpModel model({1}, {1, 2, 2, 1}, {1, 2, 2, 1}, TensorType_INT32); model.PopulateTensor(model.input1(), {false}); - model.PopulateTensor(model.input2(), {1, 2, 3, 4}); - model.PopulateTensor(model.input3(), {5, 6, 7, 8}); + model.PopulateTensor(model.input2(), {1, 2, 3, 4}); + model.PopulateTensor(model.input3(), {5, 6, 7, 8}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 6, 7, 8})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 6, 7, 8})); EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2, 1})); } diff --git a/tensorflow/contrib/lite/kernels/shape.cc b/tensorflow/contrib/lite/kernels/shape.cc new file mode 100644 index 0000000000000000000000000000000000000000..dbcd2ef004f490f00193153be7a2cfda83e73c24 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/shape.cc @@ -0,0 +1,93 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace shape { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +template +void ExtractShape(const TfLiteTensor* input, OutType* output_data) { + for (int i = 0; i < NumDimensions(input); ++i) { + output_data[i] = SizeOfDimension(input, i); + } +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + auto* params = reinterpret_cast(node->builtin_data); + switch (params->out_type) { + case kTfLiteInt32: + output->type = kTfLiteInt32; + break; + case kTfLiteInt64: + output->type = kTfLiteInt64; + break; + default: + context->ReportError(context, "Unknown shape output data type: %d", + params->out_type); + return kTfLiteError; + } + + // Shape always produces a 1-dimensional output tensor, where each output + // element is the length of the corresponding input tensor's dimension. + TfLiteIntArray* output_size = TfLiteIntArrayCreate(1); + output_size->data[0] = NumDimensions(input); + return context->ResizeTensor(context, output, output_size); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TFLITE_DCHECK_EQ(NumDimensions(output), 1); + TFLITE_DCHECK_EQ(SizeOfDimension(output, 0), NumDimensions(input)); + + switch (output->type) { + case kTfLiteInt32: + ExtractShape(input, GetTensorData(output)); + break; + case kTfLiteInt64: + ExtractShape(input, GetTensorData(output)); + break; + default: + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace shape + +TfLiteRegistration* Register_SHAPE() { + static TfLiteRegistration r = {nullptr, nullptr, shape::Prepare, shape::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/shape_test.cc b/tensorflow/contrib/lite/kernels/shape_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..27b48f4e992a8f02d56815bd1bd9074f5b41f400 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/shape_test.cc @@ -0,0 +1,95 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +template +class ShapeOpModel : public SingleOpModel { + public: + ShapeOpModel(std::initializer_list input_shape, TensorType input_type, + TensorType output_type) { + input_ = AddInput(input_type); + output_ = AddOutput(output_type); + SetBuiltinOp(BuiltinOperator_SHAPE, BuiltinOptions_ShapeOptions, + CreateShapeOptions(builder_, output_type).Union()); + BuildInterpreter({input_shape}); + } + + TfLiteStatus InvokeWithResult() { return interpreter_->Invoke(); } + + int input() { return input_; } + + int32_t GetOutputSize() { return GetTensorSize(output_); } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(ShapeOpTest, OutTypeInt) { + ShapeOpModel model({1, 3, 1, 3, 5}, TensorType_FLOAT32, + TensorType_INT32); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({5})); +} + +TEST(ShapeOpTest, OutTypeInt64) { + ShapeOpModel model({1, 3, 1, 3, 5}, TensorType_FLOAT32, + TensorType_INT64); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({5})); +} + +TEST(ShapeOpTest, ScalarTensor) { + ShapeOpModel model({}, TensorType_FLOAT32, TensorType_INT32); + model.Invoke(); + + EXPECT_EQ(model.GetOutputSize(), 0); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({0})); +} + +TEST(ShapeOpTest, EmptyTensor) { + ShapeOpModel model({1, 0}, TensorType_FLOAT32, TensorType_INT32); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/softmax_test.cc b/tensorflow/contrib/lite/kernels/softmax_test.cc index 6c5338ff0fd26337c9adc8e0b94a0a88edfde37f..727822f6beaa8a63ca8f1b57ba4993d2e59f7e0b 100644 --- a/tensorflow/contrib/lite/kernels/softmax_test.cc +++ b/tensorflow/contrib/lite/kernels/softmax_test.cc @@ -92,10 +92,9 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaEq1) { m.Invoke(); std::unique_ptr output_buffer(new float[input_size * batch_size]); - static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size}, - {1, 0, 0, input_size}}; - tflite::reference_ops::Softmax(input_buffer, input_dims, beta, - output_buffer.get(), input_dims); + auto input_shape = RuntimeShape({batch_size, 1, 1, input_size}); + tflite::reference_ops::Softmax(input_buffer, input_shape, beta, + output_buffer.get(), input_shape); std::vector expected; expected.insert(expected.end(), output_buffer.get(), @@ -120,10 +119,9 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) { m.Invoke(); std::unique_ptr output_buffer(new float[input_size * batch_size]); - static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size}, - {1, 0, 0, input_size}}; - tflite::reference_ops::Softmax(input_buffer, input_dims, beta, - output_buffer.get(), input_dims); + auto input_shape = RuntimeShape({batch_size, 1, 1, input_size}); + tflite::reference_ops::Softmax(input_buffer, input_shape, beta, + output_buffer.get(), input_shape); std::vector expected; expected.insert(expected.end(), output_buffer.get(), diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/contrib/lite/kernels/split.cc index 43387df9ceb4d54a2784c3fa4718a95262948729..b14448604123253bac9c50c21f047891721ab122 100644 --- a/tensorflow/contrib/lite/kernels/split.cc +++ b/tensorflow/contrib/lite/kernels/split.cc @@ -76,8 +76,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumOutputs(node), op_context.params->num_splits); auto input_type = op_context.input->type; - TF_LITE_ENSURE(context, - input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8); + TF_LITE_ENSURE(context, input_type == kTfLiteFloat32 || + input_type == kTfLiteUInt8 || + input_type == kTfLiteInt16); for (int i = 0; i < NumOutputs(node); ++i) { GetOutput(context, node, i)->type = input_type; } @@ -137,9 +138,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_SPLIT(uint8_t); break; } + case kTfLiteInt16: { + TF_LITE_SPLIT(int16_t); + break; + } default: context->ReportError( - context, "Only float32 and uint8 are currently supported, got %d.", + context, + "Only float32, uint8 and int16 are currently supported, got %d.", op_context.input->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc index 725dd8105ab9506d5203ed38a11f8e06abdab603..bed2117f9ae3a64e963478eb03b46f0547f4c05f 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice.cc +++ b/tensorflow/contrib/lite/kernels/strided_slice.cc @@ -121,10 +121,19 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, int32_t begin = GetBeginValueAtIndex(op_context, idx); int32_t end = GetEndValueAtIndex(op_context, idx); + // When shrinking an axis, the end position does not matter (and can be + // incorrect when negative indexing is used, see Issue #19260). Always use + // begin + 1 to generate a length 1 slice, since begin has + // already been adjusted for negative indices by GetBeginValueAtIndex. + const bool shrink_axis = op_context->params->shrink_axis_mask & (1 << idx); + if (shrink_axis) { + end = begin + 1; + } + // This is valid for both positive and negative strides int32_t dim_shape = ceil((end - begin) / static_cast(stride)); dim_shape = dim_shape < 0 ? 0 : dim_shape; - if (!(op_context->params->shrink_axis_mask & (1 << idx))) { + if (!shrink_axis) { output_shape_vector.push_back(dim_shape); } } @@ -204,13 +213,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { int begin_mask = ReverseMaskBits(op_context.params->begin_mask, op_context.dims); int end_mask = ReverseMaskBits(op_context.params->end_mask, op_context.dims); - -#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ - kernel_type::StridedSlice(GetTensorData(op_context.input), \ - GetTensorDims(op_context.input), begin_mask, \ - end_mask, starts, stops, strides, \ - GetTensorData(op_context.output), \ - GetTensorDims(op_context.output)) + int shrink_axis_mask = + ReverseMaskBits(op_context.params->shrink_axis_mask, op_context.dims); + +#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ + kernel_type::StridedSlice( \ + GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), begin_mask, end_mask, shrink_axis_mask, \ + starts, stops, strides, GetTensorData(op_context.output), \ + GetTensorDims(op_context.output)) switch (op_context.input->type) { case kTfLiteFloat32: diff --git a/tensorflow/contrib/lite/kernels/strided_slice_test.cc b/tensorflow/contrib/lite/kernels/strided_slice_test.cc index cc39179bc705aa1083e74b06f8f7f3fb45e9f616..c5d4f9affb46c82b4dec15bc0653d7315d132335 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice_test.cc +++ b/tensorflow/contrib/lite/kernels/strided_slice_test.cc @@ -21,7 +21,6 @@ limitations under the License. namespace tflite { namespace { -using ::int32; using ::testing::ElementsAreArray; template data) { PopulateTensor(input_, data); } - void SetBegin(std::initializer_list data) { - PopulateTensor(begin_, data); + void SetBegin(std::initializer_list data) { + PopulateTensor(begin_, data); } - void SetEnd(std::initializer_list data) { - PopulateTensor(end_, data); + void SetEnd(std::initializer_list data) { + PopulateTensor(end_, data); } - void SetStrides(std::initializer_list data) { - PopulateTensor(strides_, data); + void SetStrides(std::initializer_list data) { + PopulateTensor(strides_, data); } std::vector GetOutput() { @@ -384,6 +383,45 @@ TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); } +TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1_NegativeSlice) { + // This is equivalent to tf.range(4)[-1]. + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + m.SetInput({0, 1, 2, 3}); + m.SetBegin({-1}); + m.SetEnd({0}); + m.SetStrides({1}); + + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); +} + +TEST(StridedSliceOpTest, In2D_ShrinkAxis3_NegativeSlice) { + // This is equivalent to tf.range(4)[:, tf.newaxis][-2, -1]. + StridedSliceOpModel<> m({4, 1}, {2}, {2}, {2}, 0, 0, 0, 0, 3); + m.SetInput({0, 1, 2, 3}); + m.SetBegin({-2, -1}); + m.SetEnd({-1, 0}); + m.SetStrides({1, 1}); + + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); +} + +TEST(StridedSliceOpTest, In2D_ShrinkAxis2_BeginEndAxis1_NegativeSlice) { + // This is equivalent to tf.range(4)[:, tf.newaxis][:, -1]. + StridedSliceOpModel<> m({4, 1}, {2}, {2}, {2}, 1, 1, 0, 0, 2); + m.SetInput({0, 1, 2, 3}); + m.SetBegin({0, -1}); + m.SetEnd({0, 0}); + m.SetStrides({1, 1}); + + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 1, 2, 3})); +} + TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) { StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4}); @@ -395,17 +433,6 @@ TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); } -TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStrideShrinkAxisMask1) { - StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); - m.SetInput({1, 2, 3, 4}); - m.SetBegin({-2}); - m.SetEnd({-3}); - m.SetStrides({-1}); - m.Invoke(); - EXPECT_TRUE(m.GetOutputShape().empty()); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); -} - TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) { StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4, 5, 6}); @@ -538,7 +565,7 @@ TEST(StridedSliceOpTest, RunTwice) { } TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc index a8b803589962032db3ed579d31e8b736c3afada0..1247525d416e8166a9e2e1d67c7907c00b0f6723 100644 --- a/tensorflow/contrib/lite/kernels/sub.cc +++ b/tensorflow/contrib/lite/kernels/sub.cc @@ -83,8 +83,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input1, const TfLiteTensor* input2, TfLiteTensor* output) { float output_activation_min, output_activation_max; - CalculateActivationRangeFloat(params->activation, &output_activation_min, - &output_activation_max); + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); #define TF_LITE_SUB(type, opname) \ type::opname(GetTensorData(input1), GetTensorDims(input1), \ GetTensorData(input2), GetTensorDims(input2), \ diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc index 308860c299e9d74729d35b760e0f605437872c92..43ac3a2ce86df6dc9a0dd914851174aaf33b25be 100644 --- a/tensorflow/contrib/lite/kernels/svdf.cc +++ b/tensorflow/contrib/lite/kernels/svdf.cc @@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + +// SVDF op that compresses a fully connected op via low-rank matrix +// factorization. See https://research.google.com/pubs/archive/43813.pdf for +// details. #include #include #include @@ -32,6 +36,67 @@ namespace ops { namespace builtin { namespace svdf { +namespace { + +struct OpData { + int scratch_tensor_index; + bool float_weights_time_initialized; +}; + +static inline void ApplyTimeWeightsBiasAndActivation( + int batch_size, int memory_size, int num_filters, int num_units, int rank, + const TfLiteTensor* weights_time, const TfLiteTensor* bias, + TfLiteFusedActivation activation, TfLiteTensor* state, + TfLiteTensor* scratch, TfLiteTensor* output) { + // Compute matmul(state, weights_time). + // The right most column is used to save temporary output (with the size of + // num_filters). This is achieved by starting at state->data.f and having the + // stride equal to memory_size. + for (int b = 0; b < batch_size; ++b) { + float* state_ptr_batch = state->data.f + b * memory_size * num_filters; + float* scratch_ptr_batch = scratch->data.f + b * num_filters; + tensor_utils::BatchVectorBatchVectorDotProduct( + weights_time->data.f, state_ptr_batch, memory_size, num_filters, + scratch_ptr_batch, /*result_stride=*/1); + } + + // Initialize output with bias if provided. + if (bias) { + tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size, + output->data.f); + } else { + tensor_utils::ZeroVector(output->data.f, batch_size * num_units); + } + + // Reduction sum. + for (int b = 0; b < batch_size; ++b) { + float* output_ptr_batch = output->data.f + b * num_units; + float* scratch_ptr_batch = scratch->data.f + b * num_filters; + tensor_utils::ReductionSumVector(scratch_ptr_batch, output_ptr_batch, + num_units, rank); + } + + // Apply activation. + for (int b = 0; b < batch_size; ++b) { + float* output_ptr_batch = output->data.f + b * num_units; + tensor_utils::ApplyActivationToVector(output_ptr_batch, num_units, + activation, output_ptr_batch); + } + + // Left shift the state to make room for next cycle's activation. + // TODO(alanchiao): explore collapsing this into a single loop. + for (int b = 0; b < batch_size; ++b) { + float* state_ptr_batch = state->data.f + b * memory_size * num_filters; + for (int f = 0; f < num_filters; ++f) { + tensor_utils::VectorShiftLeft(state_ptr_batch, memory_size, + /*shift_value=*/0.0); + state_ptr_batch += memory_size; + } + } +} + +} // namespace + constexpr int kInputTensor = 0; constexpr int kWeightsFeatureTensor = 1; constexpr int kWeightsTimeTensor = 2; @@ -40,29 +105,34 @@ constexpr int kStateTensor = 0; constexpr int kOutputTensor = 1; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* scratch_tensor_index = new int; - context->AddTensors(context, 1, scratch_tensor_index); - return scratch_tensor_index; + auto* op_data = new OpData; + op_data->float_weights_time_initialized = false; + context->AddTensors(context, /*tensors_to_add=*/4, + &op_data->scratch_tensor_index); + return op_data; } void Free(TfLiteContext* context, void* buffer) { - delete reinterpret_cast(buffer); + delete reinterpret_cast(buffer); } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); - int* scratch_tensor_index = reinterpret_cast(node->user_data); + const auto* params = reinterpret_cast(node->builtin_data); + OpData* op_data = reinterpret_cast(node->user_data); + int scratch_tensor_index = op_data->scratch_tensor_index; // Check we have all the inputs and outputs we need. TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); - TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* weights_feature = GetInput(context, node, kWeightsFeatureTensor); const TfLiteTensor* weights_time = GetInput(context, node, kWeightsTimeTensor); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + // Check all the parameters of tensor match within themselves and match the // input configuration. const int rank = params->rank; @@ -103,10 +173,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, output_size_array)); + // The weights are of consistent type, so it suffices to check one. + const bool is_hybrid_op = + (input->type == kTfLiteFloat32 && weights_feature->type == kTfLiteUInt8); + // Resize scratch. TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(1); - node->temporaries->data[0] = *scratch_tensor_index; + if (is_hybrid_op) { + node->temporaries = TfLiteIntArrayCreate(4); + } else { + node->temporaries = TfLiteIntArrayCreate(1); + } + node->temporaries->data[0] = scratch_tensor_index; TfLiteIntArray* scratch_size_array = TfLiteIntArrayCreate(2); scratch_size_array->data[0] = batch_size; @@ -118,24 +196,56 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_tensor, scratch_size_array)); - return kTfLiteOk; -} - -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); - - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const TfLiteTensor* weights_feature = - GetInput(context, node, kWeightsFeatureTensor); - const TfLiteTensor* weights_time = - GetInput(context, node, kWeightsTimeTensor); + if (is_hybrid_op) { + // Tell interpreter to allocate temporary tensors to store quantized values + // of input tensors. + node->temporaries->data[1] = scratch_tensor_index + 1; + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); + input_quantized->type = kTfLiteUInt8; + input_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { + TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, + input_quantized_size)); + } - TfLiteTensor* state = GetOutput(context, node, kStateTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0); + // Tell interpreter to allocate temporary tensors to store scaling factors. + node->temporaries->data[2] = scratch_tensor_index + 2; + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = batch_size; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } - const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + // Used to store dequantized weights_time matrix for hybrid computation + // of matmul(state, weights_time), which occurs in floating point. + node->temporaries->data[3] = scratch_tensor_index + 3; + TfLiteTensor* float_weights_time = GetTemporary(context, node, /*index=*/3); + float_weights_time->type = kTfLiteFloat32; + // Persistent so that we can compute the dequantized weights only once. + float_weights_time->allocation_type = kTfLiteArenaRwPersistent; + if (!TfLiteIntArrayEqual(float_weights_time->dims, weights_time->dims)) { + TfLiteIntArray* float_weights_time_size = + TfLiteIntArrayCopy(weights_time->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, float_weights_time, + float_weights_time_size)); + } + } + return kTfLiteOk; +} +TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, + const TfLiteTensor* input, + const TfLiteTensor* weights_feature, + const TfLiteTensor* weights_time, + const TfLiteTensor* bias, const TfLiteSVDFParams* params, + TfLiteTensor* scratch, TfLiteTensor* state, + TfLiteTensor* output) { const int rank = params->rank; const int batch_size = input->dims->data[0]; const int input_size = input->dims->data[1]; @@ -146,67 +256,150 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Clear the activation (state left most column). // TODO(ghodrat): Add a test which initialize state with invalid values in // left most column and make sure it passes. - for (int b = 0; b < batch_size; b++) { + for (int b = 0; b < batch_size; ++b) { float* state_ptr_batch = state->data.f + b * memory_size * num_filters; - for (int c = 0; c < num_filters; c++) { + for (int c = 0; c < num_filters; ++c) { float* state_ptr = state_ptr_batch + c * memory_size; state_ptr[memory_size - 1] = 0.0; } } // Compute conv1d(inputs, weights_feature). - // The state left most column is used to save current cycle activation. This + // The state right most column is used to save current cycle activation. This // is achieved by starting at state->data.f[memory_size - 1] and having the // stride equal to memory_size. tensor_utils::MatrixBatchVectorMultiplyAccumulate( weights_feature->data.f, num_filters, input_size, input->data.f, batch_size, &state->data.f[memory_size - 1], memory_size); - // Compute matmul(state, weights_time). - // The right most column is used to save temporary output (with the size of - // num_filters). This is achieved by starting at state->data.f and having the - // stride equal to memory_size. - for (int b = 0; b < batch_size; b++) { + ApplyTimeWeightsBiasAndActivation(batch_size, memory_size, num_filters, + num_units, rank, weights_time, bias, + params->activation, state, scratch, output); + return kTfLiteOk; +} + +TfLiteStatus EvalHybrid( + TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input, + const TfLiteTensor* weights_feature, const TfLiteTensor* weights_time, + const TfLiteTensor* bias, const TfLiteSVDFParams* params, + TfLiteTensor* scratch, TfLiteTensor* scaling_factors, + TfLiteTensor* input_quantized, TfLiteTensor* state, TfLiteTensor* output) { + const int rank = params->rank; + const int batch_size = input->dims->data[0]; + const int input_size = input->dims->data[1]; + const int num_filters = weights_feature->dims->data[0]; + const int num_units = num_filters / rank; + const int memory_size = weights_time->dims->data[1]; + + // Initialize the pointer to input. + const float* input_ptr_batch = input->data.f; + + // Initialize the pointer to storage for quantized values and + // scaling factors. + int8_t* quantized_input_ptr_batch = + reinterpret_cast(input_quantized->data.uint8); + + float* scaling_factors_ptr = scaling_factors->data.f; + + // Other initializations. + const int8_t* weights_feature_ptr = + reinterpret_cast(weights_feature->data.uint8); + const float weights_feature_scale = weights_feature->params.scale; + + // Clear the activation (state left most column). + // TODO(ghodrat): Add a test which initialize state with invalid values in + // left most column and make sure it passes. + for (int b = 0; b < batch_size; ++b) { float* state_ptr_batch = state->data.f + b * memory_size * num_filters; - float* scratch_ptr_batch = scratch->data.f + b * num_filters; - tensor_utils::BatchVectorBatchVectorDotProduct( - weights_time->data.f, state_ptr_batch, memory_size, num_filters, - scratch_ptr_batch, /*result_stride=*/1); + for (int c = 0; c < num_filters; ++c) { + float* state_ptr = state_ptr_batch + c * memory_size; + state_ptr[memory_size - 1] = 0.0; + } } - // Initialize output with bias if provided. - if (bias) { - tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size, - output->data.f); - } else { - tensor_utils::ZeroVector(output->data.f, batch_size * num_units); - } + if (!tensor_utils::IsZeroVector(input_ptr_batch, batch_size * input_size)) { + // Quantize input from float to int8. + float unused_min, unused_max; + for (int b = 0; b < batch_size; ++b) { + const int offset = b * input_size; + tensor_utils::SymmetricQuantizeFloats( + input_ptr_batch + offset, input_size, + quantized_input_ptr_batch + offset, &unused_min, &unused_max, + &scaling_factors_ptr[b]); + scaling_factors_ptr[b] *= weights_feature_scale; + } - // Reduction sum - for (int b = 0; b < batch_size; b++) { - float* output_ptr_batch = output->data.f + b * num_units; - float* scratch_ptr_batch = scratch->data.f + b * num_filters; - tensor_utils::ReductionSumVector(scratch_ptr_batch, output_ptr_batch, - num_units, rank); + // Compute conv1d(inputs, weights_feature). + // The state right most column is used to save current cycle activation. + // This is achieved by starting at state->data.f[memory_size - 1] and having + // the stride equal to memory_size. + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + weights_feature_ptr, num_filters, input_size, quantized_input_ptr_batch, + scaling_factors_ptr, batch_size, &state->data.f[memory_size - 1], + memory_size); } - // Apply activation. - for (int b = 0; b < batch_size; b++) { - float* output_ptr_batch = output->data.f + b * num_units; - tensor_utils::ApplyActivationToVector(output_ptr_batch, num_units, - params->activation, output_ptr_batch); - } + // TODO(alanchiao): can optimize hybrid case ~5% by unrolling loop in applying + // time weights so that the inner loop multiplies eight elements at a time. + ApplyTimeWeightsBiasAndActivation(batch_size, memory_size, num_filters, + num_units, rank, weights_time, bias, + params->activation, state, scratch, output); + return kTfLiteOk; +} - // Right shift the state. - for (int b = 0; b < batch_size; b++) { - float* state_ptr_batch = state->data.f + b * memory_size * num_filters; - for (int f = 0; f < num_filters; f++) { - tensor_utils::VectorShiftLeft(state_ptr_batch, memory_size, - /*shift_value=*/0.0); - state_ptr_batch += memory_size; +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* op_data = reinterpret_cast(node->user_data); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* weights_feature = + GetInput(context, node, kWeightsFeatureTensor); + const TfLiteTensor* weights_time = + GetInput(context, node, kWeightsTimeTensor); + const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + + TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0); + + TfLiteTensor* state = GetOutput(context, node, kStateTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + switch (weights_feature->type) { + case kTfLiteFloat32: { + return EvalFloat(context, node, input, weights_feature, weights_time, + bias, params, scratch, state, output); + break; } + case kTfLiteUInt8: { + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2); + TfLiteTensor* float_weights_time = + GetTemporary(context, node, /*index=*/3); + + // Dequantize weights time. + // TODO(alanchiao): this dequantization initialization only needs to + // happen once per model and should theoretically be placed in either Init + // or Prepare. However, TFLite doesn't allocate float_weights_time until + // the Eval function. + // TODO(alanchiao): refactor logic out into dequantize function. + if (!op_data->float_weights_time_initialized) { + const float inv_scale = 1.0 / weights_time->params.scale; + const int8_t* weights_time_ptr = + reinterpret_cast(weights_time->data.uint8); + for (int i = 0; i < NumElements(float_weights_time); ++i) { + float_weights_time->data.f[i] = weights_time_ptr[i] * inv_scale; + } + op_data->float_weights_time_initialized = true; + } + return EvalHybrid(context, node, input, weights_feature, + float_weights_time, bias, params, scratch, + scaling_factors, input_quantized, state, output); + break; + } + default: + context->ReportError(context, "Type %d not currently supported.", + weights_feature->type); + return kTfLiteError; } - return kTfLiteOk; } } // namespace svdf diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc index 0f166dc69b95f3459388135b3a6c4d9b73a31cb4..06df509d32dacc25fbcf84606b5218697c831e96 100644 --- a/tensorflow/contrib/lite/kernels/svdf_test.cc +++ b/tensorflow/contrib/lite/kernels/svdf_test.cc @@ -126,17 +126,20 @@ static float svdf_golden_output_rank_2[] = { }; // Derived class of SingleOpModel, which is used to test SVDF TFLite op. -class SVDFOpModel : public SingleOpModel { +class BaseSVDFOpModel : public SingleOpModel { public: - SVDFOpModel(int batches, int units, int input_size, int memory_size, int rank) + BaseSVDFOpModel(int batches, int units, int input_size, int memory_size, + int rank, + TensorType weights_feature_type = TensorType_FLOAT32, + TensorType weights_time_type = TensorType_FLOAT32) : batches_(batches), units_(units), input_size_(input_size), memory_size_(memory_size), rank_(rank) { input_ = AddInput(TensorType_FLOAT32); - weights_feature_ = AddInput(TensorType_FLOAT32); - weights_time_ = AddInput(TensorType_FLOAT32); + weights_feature_ = AddInput(weights_feature_type); + weights_time_ = AddInput(weights_time_type); bias_ = AddNullInput(); state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); @@ -182,7 +185,7 @@ class SVDFOpModel : public SingleOpModel { int num_units() { return units_; } int num_batches() { return batches_; } - private: + protected: int input_; int weights_feature_; int weights_time_; @@ -197,7 +200,61 @@ class SVDFOpModel : public SingleOpModel { int rank_; }; -TEST(SVDFOpTest, BlackBoxTestRank1) { +class SVDFOpModel : public BaseSVDFOpModel { + public: + using BaseSVDFOpModel::BaseSVDFOpModel; +}; + +class HybridSVDFOpModel : public BaseSVDFOpModel { + public: + HybridSVDFOpModel(int batches, int units, int input_size, int memory_size, + int rank) + : BaseSVDFOpModel(batches, units, input_size, memory_size, rank, + TensorType_UINT8, TensorType_UINT8) {} + + void SetWeightsFeature(std::initializer_list f) { + SymmetricQuantizeAndPopulate(weights_feature_, f); + } + + void SetWeightsTime(std::initializer_list f) { + SymmetricQuantizeAndPopulate(weights_time_, f); + } +}; + +class SVDFOpTest : public ::testing::Test { + protected: + void VerifyGoldens(float golden_input[], float golden_output[], + int golden_size, BaseSVDFOpModel* svdf, + float tolerance = 1e-5) { + const int svdf_num_batches = svdf->num_batches(); + const int svdf_input_size = svdf->input_size(); + const int svdf_num_units = svdf->num_units(); + const int input_sequence_size = + golden_size / sizeof(float) / (svdf_input_size * svdf_num_batches); + // Going over each input batch, setting the input tensor, invoking the SVDF + // op and checking the output with the expected golden values. + for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = + golden_input + i * svdf_input_size * svdf_num_batches; + float* batch_end = batch_start + svdf_input_size * svdf_num_batches; + svdf->SetInput(0, batch_start, batch_end); + + svdf->Invoke(); + + const float* golden_start = + golden_output + i * svdf_num_units * svdf_num_batches; + const float* golden_end = + golden_start + svdf_num_units * svdf_num_batches; + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(svdf->GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + } + } +}; + +TEST_F(SVDFOpTest, BlackBoxTestRank1) { SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, /*memory_size=*/10, /*rank=*/1); svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, @@ -218,31 +275,11 @@ TEST(SVDFOpTest, BlackBoxTestRank1) { -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657}); svdf.ResetState(); - const int svdf_num_batches = svdf.num_batches(); - const int svdf_input_size = svdf.input_size(); - const int svdf_num_units = svdf.num_units(); - const int input_sequence_size = - sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches); - // Going over each input batch, setting the input tensor, invoking the SVDF op - // and checking the output with the expected golden values. - for (int i = 0; i < input_sequence_size; i++) { - float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches; - float* batch_end = batch_start + svdf_input_size * svdf_num_batches; - svdf.SetInput(0, batch_start, batch_end); - - svdf.Invoke(); - - float* golden_start = - svdf_golden_output_rank_1 + i * svdf_num_units * svdf_num_batches; - float* golden_end = golden_start + svdf_num_units * svdf_num_batches; - std::vector expected; - expected.insert(expected.end(), golden_start, golden_end); - - EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); - } + VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input), + &svdf); } -TEST(SVDFOpTest, BlackBoxTestRank2) { +TEST_F(SVDFOpTest, BlackBoxTestRank2) { SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, /*memory_size=*/10, /*rank=*/2); svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, @@ -278,28 +315,75 @@ TEST(SVDFOpTest, BlackBoxTestRank2) { 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763}); svdf.ResetState(); - const int svdf_num_batches = svdf.num_batches(); - const int svdf_input_size = svdf.input_size(); - const int svdf_num_units = svdf.num_units(); - const int input_sequence_size = - sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches); - // Going over each input batch, setting the input tensor, invoking the SVDF op - // and checking the output with the expected golden values. - for (int i = 0; i < input_sequence_size; i++) { - float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches; - float* batch_end = batch_start + svdf_input_size * svdf_num_batches; - svdf.SetInput(0, batch_start, batch_end); - - svdf.Invoke(); - - float* golden_start = - svdf_golden_output_rank_2 + i * svdf_num_units * svdf_num_batches; - float* golden_end = golden_start + svdf_num_units * svdf_num_batches; - std::vector expected; - expected.insert(expected.end(), golden_start, golden_end); - - EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); - } + VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input), + &svdf); +} + +TEST_F(SVDFOpTest, BlackBoxTestHybridRank1) { + HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/1); + svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, + 0.22197971, 0.12416199, 0.27901134, 0.27557442, + 0.3905206, -0.36137494, -0.06634006, -0.10640851}); + + svdf.SetWeightsTime( + {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657}); + + svdf.ResetState(); + VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input), + &svdf, + /*tolerance=*/0.00294435); +} + +TEST_F(SVDFOpTest, BlackBoxTestHybridRank2) { + HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/2); + svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, + 0.12416199, 0.15785322, 0.27901134, 0.3905206, + 0.21931258, -0.36137494, -0.10640851, 0.31053296, + -0.36118156, -0.0976817, -0.36916667, 0.22197971, + 0.15294972, 0.38031587, 0.27557442, 0.39635518, + -0.21580373, -0.06634006, -0.02702999, 0.27072677}); + + svdf.SetWeightsTime( + {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657, + + -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486, + 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187, + + -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589, + 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836, + + -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277, + -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214, + + 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326, + 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763}); + + svdf.ResetState(); + VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input), + &svdf, + /*tolerance=*/0.00625109); } } // namespace diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc index d23ec201b41887b0682242687fc938d76d058c44..9156917140b5af6c0f38c878ab77fef7f93b049a 100644 --- a/tensorflow/contrib/lite/kernels/test_util.cc +++ b/tensorflow/contrib/lite/kernels/test_util.cc @@ -32,8 +32,8 @@ std::vector> ArrayFloatNear(const std::vector& values, return matchers; } -int SingleOpModel::AddInput(const TensorData& t) { - int id = AddTensor(t, {}); +int SingleOpModel::AddInput(const TensorData& t, bool is_variable) { + int id = AddTensor(t, {}, is_variable); inputs_.push_back(id); return id; } @@ -120,6 +120,7 @@ void SingleOpModel::BuildInterpreter( CHECK(interpreter_->AllocateTensors() == kTfLiteOk) << "Cannot allocate tensors"; + interpreter_->ResetVariableTensorsToZero(); } void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); } diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h index db80c0082c394a2cb2f9388d3db5bd1a7cbe6266..bedbe93ae65662647f6a0fb0c9c6a6a921e148bb 100644 --- a/tensorflow/contrib/lite/kernels/test_util.h +++ b/tensorflow/contrib/lite/kernels/test_util.h @@ -126,8 +126,10 @@ class SingleOpModel { SingleOpModel& operator=(const SingleOpModel&) = delete; // Add a TensorType input tensor and return its index. - int AddInput(TensorType type) { return AddInput(TensorData{type}); } - int AddInput(const TensorData& t); + int AddInput(TensorType type, bool is_variable = false) { + return AddInput(TensorData{type}, is_variable); + } + int AddInput(const TensorData& t, bool is_variable = false); // Templated version of AddConstInput(). template @@ -146,20 +148,18 @@ class SingleOpModel { int AddOutput(const TensorData& t); template - void QuantizeAndPopulate(int index, std::initializer_list data) { + void QuantizeAndPopulate(int index, const std::vector& data) { TfLiteTensor* t = interpreter_->tensor(index); auto q = Quantize(data, t->params.scale, t->params.zero_point); PopulateTensor(index, 0, q.data(), q.data() + q.size()); } - void SymmetricQuantizeAndPopulate(int index, - std::initializer_list data) { + void SymmetricQuantizeAndPopulate(int index, const std::vector& data) { TfLiteTensor* t = interpreter_->tensor(index); - std::vector values(data); - const int length = values.size(); + const int length = data.size(); std::vector q(length); float min, max, scaling_factor; - tensor_utils::SymmetricQuantizeFloats(values.data(), length, q.data(), &min, + tensor_utils::SymmetricQuantizeFloats(data.data(), length, q.data(), &min, &max, &scaling_factor); // Update quantization params. t->params.scale = scaling_factor; @@ -196,8 +196,22 @@ class SingleOpModel { } // Populate the tensor given its index. + // TODO(b/110696148) clean up and merge with vector-taking variant below. + template + void PopulateTensor(int index, const std::initializer_list& data) { + T* v = interpreter_->typed_tensor(index); + CHECK(v) << "No tensor with index '" << index << "'."; + for (T f : data) { + *v = f; + ++v; + } + } + + // Populate the tensor given its index. + // TODO(b/110696148) clean up and merge with initializer_list-taking variant + // above. template - void PopulateTensor(int index, std::initializer_list data) { + void PopulateTensor(int index, const std::vector& data) { T* v = interpreter_->typed_tensor(index); CHECK(v) << "No tensor with index '" << index << "'."; for (T f : data) { @@ -260,7 +274,8 @@ class SingleOpModel { } template - int AddTensor(TensorData t, std::initializer_list data) { + int AddTensor(TensorData t, std::initializer_list data, + bool is_variable = false) { int id = tensors_.size(); // This is slightly different depending on whether we are adding a @@ -277,6 +292,9 @@ class SingleOpModel { } else if (t.type == TensorType_INT32) { std::tie(t.scale, t.zero_point) = QuantizationParams(t.min, t.max); + } else if (t.type == TensorType_INT16) { + std::tie(t.scale, t.zero_point) = + QuantizationParams(t.min, t.max); } else { LOG(FATAL) << "No support for the requested quantized type"; } @@ -309,7 +327,7 @@ class SingleOpModel { tensors_.push_back(CreateTensor(builder_, builder_.CreateVector(t.shape), t.type, /*buffer=*/buffer_id, - /*name=*/0, q_params)); + /*name=*/0, q_params, is_variable)); tensor_data_[id] = t; diff --git a/tensorflow/contrib/lite/kernels/test_util_test.cc b/tensorflow/contrib/lite/kernels/test_util_test.cc index 1e10e89061213b6fcabd404310893dd97a51d83f..236580347254d336609a3081736f54e069b5cb5a 100644 --- a/tensorflow/contrib/lite/kernels/test_util_test.cc +++ b/tensorflow/contrib/lite/kernels/test_util_test.cc @@ -22,22 +22,22 @@ using ::testing::ElementsAreArray; TEST(TestUtilTest, QuantizeVector) { std::vector data = {-1.0, -0.5, 0.0, 0.5, 1.0, 1000.0}; - auto q_data = Quantize(data, /*scale=*/1.0, /*zero_point=*/0); - std::vector expected = {0, 0, 0, 1, 1, 255}; + auto q_data = Quantize(data, /*scale=*/1.0, /*zero_point=*/0); + std::vector expected = {0, 0, 0, 1, 1, 255}; EXPECT_THAT(q_data, ElementsAreArray(expected)); } TEST(TestUtilTest, QuantizeVectorScalingDown) { std::vector data = {-1.0, -0.5, 0.0, 0.5, 1.0, 1000.0}; - auto q_data = Quantize(data, /*scale=*/10.0, /*zero_point=*/0); - std::vector expected = {0, 0, 0, 0, 0, 100}; + auto q_data = Quantize(data, /*scale=*/10.0, /*zero_point=*/0); + std::vector expected = {0, 0, 0, 0, 0, 100}; EXPECT_THAT(q_data, ElementsAreArray(expected)); } TEST(TestUtilTest, QuantizeVectorScalingUp) { std::vector data = {-1.0, -0.5, 0.0, 0.5, 1.0, 1000.0}; - auto q_data = Quantize(data, /*scale=*/0.1, /*zero_point=*/0); - std::vector expected = {0, 0, 0, 5, 10, 255}; + auto q_data = Quantize(data, /*scale=*/0.1, /*zero_point=*/0); + std::vector expected = {0, 0, 0, 5, 10, 255}; EXPECT_THAT(q_data, ElementsAreArray(expected)); } diff --git a/tensorflow/contrib/lite/kernels/tile_test.cc b/tensorflow/contrib/lite/kernels/tile_test.cc index a134a75d56ae03a5d03a3cdf632146474b863971..4f78c224e54f0c71bc6622134a1c8e4142c22daa 100644 --- a/tensorflow/contrib/lite/kernels/tile_test.cc +++ b/tensorflow/contrib/lite/kernels/tile_test.cc @@ -38,27 +38,27 @@ class TileOpModel : public SingleOpModel { PopulateTensor(input_, data); } - void SetInputUInt8(std::initializer_list data) { - PopulateTensor(input_, data); + void SetInputUInt8(std::initializer_list data) { + PopulateTensor(input_, data); } - void SetInputInt32(std::initializer_list data) { - PopulateTensor(input_, data); + void SetInputInt32(std::initializer_list data) { + PopulateTensor(input_, data); } void SetInputInt64(std::initializer_list data) { PopulateTensor(input_, data); } - void SetMultipliers(std::initializer_list data) { - PopulateTensor(multipliers_, data); + void SetMultipliers(std::initializer_list data) { + PopulateTensor(multipliers_, data); } std::vector GetOutputFloat() { return ExtractVector(output_); } - std::vector GetOutputUInt8() { return ExtractVector(output_); } + std::vector GetOutputUInt8() { return ExtractVector(output_); } - std::vector GetOutputInt32() { return ExtractVector(output_); } + std::vector GetOutputInt32() { return ExtractVector(output_); } std::vector GetOutputInt64() { return ExtractVector(output_); diff --git a/tensorflow/contrib/lite/kernels/topk_v2_test.cc b/tensorflow/contrib/lite/kernels/topk_v2_test.cc index 212f8acc76d4afba56933029175f69b34ea87a3e..2abb89b617742b33b9280b15ad379422c5c9b207 100644 --- a/tensorflow/contrib/lite/kernels/topk_v2_test.cc +++ b/tensorflow/contrib/lite/kernels/topk_v2_test.cc @@ -42,32 +42,32 @@ class TopKV2OpModel : public SingleOpModel { PopulateTensor(input_, data); } - void SetInputUInt8(std::initializer_list data) { - PopulateTensor(input_, data); + void SetInputUInt8(std::initializer_list data) { + PopulateTensor(input_, data); } - void SetInputInt32(std::initializer_list data) { - PopulateTensor(input_, data); + void SetInputInt32(std::initializer_list data) { + PopulateTensor(input_, data); } void SetInputInt64(std::initializer_list data) { PopulateTensor(input_, data); } - std::vector GetIndexes() { - return ExtractVector(output_indexes_); + std::vector GetIndexes() { + return ExtractVector(output_indexes_); } std::vector GetValuesFloat() { return ExtractVector(output_values_); } - std::vector GetValuesUInt8() { - return ExtractVector(output_values_); + std::vector GetValuesUInt8() { + return ExtractVector(output_values_); } - std::vector GetValuesInt32() { - return ExtractVector(output_values_); + std::vector GetValuesInt32() { + return ExtractVector(output_values_); } std::vector GetValuesInt64() { @@ -119,7 +119,7 @@ TEST(TopKV2OpTest, VectorFloat) { EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(ArrayFloatNear({0.8, 0.2}))); } -// Check that uint8 works. +// Check that uint8_t works. TEST(TopKV2OpTest, TypeUint8) { TopKV2OpModel m({2, 3}, TensorType_UINT8, 2); m.SetInputUInt8({1, 2, 3, 251, 250, 249}); @@ -128,7 +128,7 @@ TEST(TopKV2OpTest, TypeUint8) { EXPECT_THAT(m.GetValuesUInt8(), ElementsAreArray({3, 2, 251, 250})); } -// Check that int32 works. +// Check that int32_t works. TEST(TopKV2OpTest, TypeInt32) { TopKV2OpModel m({2, 3}, TensorType_INT32, 2); m.SetInputInt32({1, 2, 3, 10251, 10250, 10249}); diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index bc62e4cc2d8af9b1c242900a9730f4fae3b92a6c..c448fb71db204494042192d6a75ac4d600467e47 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -63,6 +63,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, case TensorType_BOOL: *type = kTfLiteBool; break; + case TensorType_COMPLEX64: + *type = kTfLiteComplex64; + break; default: error_reporter->Report("Unimplemented data type %s (%d) in tensor\n", EnumNameTensorType(tensor_type), tensor_type); @@ -444,6 +447,18 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, op->builtin_options_as_FullyConnectedOptions()) { params->activation = parse_activation( fully_connected_params->fused_activation_function()); + switch (fully_connected_params->weights_format()) { + case FullyConnectedOptionsWeightsFormat_DEFAULT: + params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault; + break; + case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8: + params->weights_format = + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8; + break; + default: + error_reporter->Report("Unhandled fully-connected weights format."); + return kTfLiteError; + } } *builtin_data = reinterpret_cast(params); break; @@ -597,9 +612,10 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_MEAN: { - auto* params = MallocPOD(); - if (auto* schema_params = op->builtin_options_as_MeanOptions()) { + case BuiltinOperator_MEAN: + case BuiltinOperator_SUM: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_ReducerOptions()) { params->keep_dims = schema_params->keep_dims(); } *builtin_data = reinterpret_cast(params); @@ -667,6 +683,15 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_SHAPE: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_ShapeOptions()) { + ConvertTensorType(schema_params->out_type(), ¶ms->out_type, + error_reporter); + } + *builtin_data = static_cast(params); + break; + } case BuiltinOperator_DELEGATE: { // TODO(ycling): Revisit when supporting saving delegated models. error_reporter->Report("DELEGATE op shouldn't exist in model."); @@ -703,14 +728,17 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_RELU: case BuiltinOperator_RELU6: case BuiltinOperator_RELU_N1_TO_1: + case BuiltinOperator_RSQRT: case BuiltinOperator_SELECT: case BuiltinOperator_SIN: case BuiltinOperator_SLICE: case BuiltinOperator_SPACE_TO_BATCH_ND: + case BuiltinOperator_SQRT: case BuiltinOperator_TANH: case BuiltinOperator_TILE: case BuiltinOperator_TOPK_V2: case BuiltinOperator_TRANSPOSE: + case BuiltinOperator_POW: break; } return kTfLiteOk; @@ -733,7 +761,7 @@ TfLiteStatus InterpreterBuilder::ParseNodes( } const TfLiteRegistration* registration = - flatbuffer_op_index_to_registration_[op->opcode_index()]; + flatbuffer_op_index_to_registration_[index]; if (registration == nullptr) { error_reporter_->Report("Skipping op for opcode_index %d\n", index); status = kTfLiteError; @@ -963,7 +991,7 @@ TfLiteStatus InterpreterBuilder::operator()( variables.push_back(i); } } - (**interpreter).SetVariables(variables); + (**interpreter).SetVariables(std::move(variables)); return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD index f8767b443a2aa64b666c3b6bfb7db30cc0be62ea..f18a2ca07a5f66b760e96a6d9a57db8d6c26b7b9 100644 --- a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD @@ -1,3 +1,5 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_binary") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index 999c31d4bff9279810a3661f0bb342cc4ef3ddaa..7627d89c091d08390021bb47c640749956d8796d 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -215,6 +215,17 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, augmented_inputs.push_back(next_id++); }; + auto add_vector_int32 = [&](const int* values, uint32_t num_values) { + ANeuralNetworksOperandType operand_type{ + .type = ANEURALNETWORKS_TENSOR_INT32, + .dimensionCount = 1, + .dimensions = &num_values}; + CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)) + CHECK_NN(ANeuralNetworksModel_setOperandValue( + nn_model, next_id, values, sizeof(int32_t) * num_values)); + augmented_inputs.push_back(next_id++); + }; + // Handle state tensors of RNN, LSTM, SVDF. // For each state_out tensor, a corresponding state_in operand needs to be // created for NNAPI. @@ -312,7 +323,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, }; auto add_mean_params = [&add_scalar_int32](void* data) { - auto builtin = reinterpret_cast(data); + auto builtin = reinterpret_cast(data); add_scalar_int32(builtin->keep_dims); }; @@ -327,6 +338,14 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, add_scalar_int32(builtin->activation); }; + auto add_squeeze_params = [&](void* data) { + const auto* builtin = reinterpret_cast(data); + // Note that we add the squeeze dimensions even if the dimensions were + // unspecified (empty), as NNAPI requires the operand. + add_vector_int32(builtin->squeeze_dims, + static_cast(builtin->num_squeeze_dims)); + }; + // Handle optional input tensors. auto add_optional_tensors = [&nn_model, &augmented_inputs, &next_id](int nn_type) { @@ -453,6 +472,11 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, nnapi_version = 11; // require NNAPI 1.1 nn_op_type = ANEURALNETWORKS_SUB; break; + case tflite::BuiltinOperator_SQUEEZE: + nnapi_version = 11; // requires NNAPI 1.1 + add_squeeze_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_SQUEEZE; + break; case tflite::BuiltinOperator_CONCAT_EMBEDDINGS: case tflite::BuiltinOperator_LSH_PROJECTION: case tflite::BuiltinOperator_HASHTABLE_LOOKUP: @@ -474,7 +498,6 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_TOPK_V2: case tflite::BuiltinOperator_TRANSPOSE: case tflite::BuiltinOperator_SPLIT: - case tflite::BuiltinOperator_SQUEEZE: case tflite::BuiltinOperator_STRIDED_SLICE: case tflite::BuiltinOperator_EXP: case tflite::BuiltinOperator_LOG_SOFTMAX: @@ -500,6 +523,11 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_SPARSE_TO_DENSE: case tflite::BuiltinOperator_EQUAL: case tflite::BuiltinOperator_NOT_EQUAL: + case tflite::BuiltinOperator_SUM: + case tflite::BuiltinOperator_SQRT: + case tflite::BuiltinOperator_RSQRT: + case tflite::BuiltinOperator_SHAPE: + case tflite::BuiltinOperator_POW: FATAL("Op code %d is currently not delegated to NNAPI", builtin); nn_op_type = -1; // set to invalid break; diff --git a/tensorflow/contrib/lite/optional_debug_tools.cc b/tensorflow/contrib/lite/optional_debug_tools.cc index 3af809a2a1034c411881bfc6a919562d326e99cf..f1f025f777c987c5ee47bdea457a973896b9bb82 100644 --- a/tensorflow/contrib/lite/optional_debug_tools.cc +++ b/tensorflow/contrib/lite/optional_debug_tools.cc @@ -52,6 +52,8 @@ const char* TensorTypeName(TfLiteType type) { return "kTfLiteBool"; case kTfLiteInt16: return "kTfLiteInt16"; + case kTfLiteComplex64: + return "kTfLiteComplex64"; } return "(invalid)"; } @@ -84,13 +86,13 @@ void PrintInterpreterState(Interpreter* interpreter) { for (int tensor_index = 0; tensor_index < interpreter->tensors_size(); tensor_index++) { TfLiteTensor* tensor = interpreter->tensor(tensor_index); - printf("Tensor %3d %10s %15s %10zu bytes (%4.1f MB) ", tensor_index, - TensorTypeName(tensor->type), AllocTypeName(tensor->allocation_type), - tensor->bytes, float(tensor->bytes) / float(1 << 20)); + printf("Tensor %3d %-20s %10s %15s %10zu bytes (%4.1f MB) ", tensor_index, + tensor->name, TensorTypeName(tensor->type), + AllocTypeName(tensor->allocation_type), tensor->bytes, + (static_cast(tensor->bytes) / (1 << 20))); PrintTfLiteIntVector(tensor->dims); - printf("\n"); } - + printf("\n"); for (int node_index = 0; node_index < interpreter->nodes_size(); node_index++) { const std::pair* node_and_reg = @@ -106,7 +108,4 @@ void PrintInterpreterState(Interpreter* interpreter) { } } -// Prints a dump of what tensors and what nodes are in the interpreter. -TfLiteStatus ValidateInterpreterState(const Interpreter* interpreter); - } // namespace tflite diff --git a/tensorflow/contrib/lite/optional_debug_tools.h b/tensorflow/contrib/lite/optional_debug_tools.h index 1b6998cda382782b974bea3d18ffb6217e8f780c..7fb4b8d8b7ae87cc6e8dd8503c8a4ce0cef2ce8d 100644 --- a/tensorflow/contrib/lite/optional_debug_tools.h +++ b/tensorflow/contrib/lite/optional_debug_tools.h @@ -24,9 +24,6 @@ namespace tflite { // Prints a dump of what tensors and what nodes are in the interpreter. void PrintInterpreterState(Interpreter* interpreter); -// Prints a dump of what tensors and what nodes are in the interpreter. -TfLiteStatus ValidateInterpreterState(const Interpreter* interpreter); - } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_ diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.cc b/tensorflow/contrib/lite/profiling/profile_summarizer.cc index 45388b500c7897c8b33b49eb6ab4e9f8c4fdb37c..c37a0965884a803e82da536f73a8f32a28691651 100644 --- a/tensorflow/contrib/lite/profiling/profile_summarizer.cc +++ b/tensorflow/contrib/lite/profiling/profile_summarizer.cc @@ -78,8 +78,13 @@ OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter, } else { op_name = tflite::EnumNamesBuiltinOperator()[code]; } + const char* profiling_string = + interpreter.OpProfilingString(node_reg->second, &node_reg->first); OperatorDetails details; details.name = op_name; + if (profiling_string) { + details.name += ":" + string(profiling_string); + } details.inputs = GetTensorNames(interpreter, inputs); details.outputs = GetTensorNames(interpreter, outputs); return details; diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc b/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc index 35cf780713b93db559f86dcaf62e1ac004b5049a..67a5eecfa05379c7a721e7d669fcd02602e5e369 100644 --- a/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc +++ b/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc @@ -31,6 +31,7 @@ namespace profiling { namespace { +#ifdef TFLITE_PROFILING_ENABLED TfLiteStatus SimpleOpEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input1 = tflite::GetInput(context, node, /*index=*/0); const TfLiteTensor* input2 = tflite::GetInput(context, node, /*index=*/1); @@ -42,20 +43,35 @@ TfLiteStatus SimpleOpEval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +const char* SimpleOpProfilingString(const TfLiteContext* context, + const TfLiteNode* node) { + return "Profile"; +} + TfLiteRegistration* RegisterSimpleOp() { + static TfLiteRegistration registration = { + nullptr, nullptr, nullptr, + SimpleOpEval, nullptr, tflite::BuiltinOperator_CUSTOM, + "SimpleOpEval", 1}; + return ®istration; +} + +TfLiteRegistration* RegisterSimpleOpWithProfilingDetails() { static TfLiteRegistration registration = {nullptr, nullptr, nullptr, SimpleOpEval, + SimpleOpProfilingString, tflite::BuiltinOperator_CUSTOM, "SimpleOpEval", 1}; return ®istration; } +#endif class SimpleOpModel : public SingleOpModel { public: - void Init(); + void Init(const std::function& registration); tflite::Interpreter* GetInterpreter() { return interpreter_.get(); } void SetInputs(int32_t x, int32_t y) { PopulateTensor(inputs_[0], {x}); @@ -68,11 +84,12 @@ class SimpleOpModel : public SingleOpModel { int output_; }; -void SimpleOpModel::Init() { +void SimpleOpModel::Init( + const std::function& registration) { inputs_[0] = AddInput({TensorType_INT32, {1}}); inputs_[1] = AddInput({TensorType_INT32, {1}}); output_ = AddOutput({TensorType_INT32, {}}); - SetCustomOp("SimpleAdd", {}, RegisterSimpleOp); + SetCustomOp("SimpleAdd", {}, registration); BuildInterpreter({GetShape(inputs_[0]), GetShape(inputs_[1])}); } @@ -86,7 +103,28 @@ TEST(ProfileSummarizerTest, Empty) { TEST(ProfileSummarizerTest, Interpreter) { Profiler profiler; SimpleOpModel m; - m.Init(); + m.Init(RegisterSimpleOp); + auto interpreter = m.GetInterpreter(); + interpreter->SetProfiler(&profiler); + profiler.StartProfiling(); + m.SetInputs(1, 2); + m.Invoke(); + // 3 = 1 + 2 + EXPECT_EQ(m.GetOutput(), 3); + profiler.StopProfiling(); + ProfileSummarizer summarizer; + auto events = profiler.GetProfileEvents(); + EXPECT_EQ(1, events.size()); + summarizer.ProcessProfiles(profiler.GetProfileEvents(), *interpreter); + auto output = summarizer.GetOutputString(); + // TODO(shashishekhar): Add a better test here. + ASSERT_TRUE(output.find("SimpleOpEval") != std::string::npos) << output; +} + +TEST(ProfileSummarizerTest, InterpreterPlusProfilingDetails) { + Profiler profiler; + SimpleOpModel m; + m.Init(RegisterSimpleOpWithProfilingDetails); auto interpreter = m.GetInterpreter(); interpreter->SetProfiler(&profiler); profiler.StartProfiling(); @@ -101,8 +139,10 @@ TEST(ProfileSummarizerTest, Interpreter) { summarizer.ProcessProfiles(profiler.GetProfileEvents(), *interpreter); auto output = summarizer.GetOutputString(); // TODO(shashishekhar): Add a better test here. - ASSERT_TRUE(output.find("SimpleOp") != std::string::npos) << output; + ASSERT_TRUE(output.find("SimpleOpEval:Profile") != std::string::npos) + << output; } + #endif } // namespace diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py index c038c88945b71f30bf091a1098dcf853f5415b1b..0ea2630f711727787332f207bdff6383aac8097c 100644 --- a/tensorflow/contrib/lite/python/convert.py +++ b/tensorflow/contrib/lite/python/convert.py @@ -25,7 +25,6 @@ import tempfile as _tempfile from tensorflow.contrib.lite.python import lite_constants from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2 from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 -from tensorflow.python.framework import dtypes as _dtypes from tensorflow.python.platform import resource_loader as _resource_loader from tensorflow.python.util.lazy_loader import LazyLoader @@ -135,11 +134,11 @@ def build_toco_convert_protos(input_tensors, input_tensors: List of input tensors. Type and shape are computed using `foo.get_shape()` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). - inference_type: Target data type of arrays in the output file. Currently - must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT) - inference_input_type: Target data type of input arrays. Allows for a - different type for input arrays in the case of quantization. Currently - must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) + inference_type: Target data type of real-number arrays in the output file. + Must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT) + inference_input_type: Target data type of real-number input arrays. Allows + for a different type for input arrays in the case of quantization. + Must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) input_format: Type of data to read Currently must be `{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF) output_format: Output file format. Currently must be `{TFLITE, @@ -202,29 +201,13 @@ def build_toco_convert_protos(input_tensors, if dump_graphviz_dir: toco.dump_graphviz_dir = dump_graphviz_dir toco.dump_graphviz_include_video = dump_graphviz_video + model = _model_flags_pb2.ModelFlags() model.change_concat_input_ranges = change_concat_input_ranges for idx, input_tensor in enumerate(input_tensors): - if input_tensor.dtype == _dtypes.float32: - tflite_input_type = lite_constants.FLOAT - elif input_tensor.dtype == _dtypes.int32: - tflite_input_type = lite_constants.INT32 - elif input_tensor.dtype == _dtypes.int64: - tflite_input_type = lite_constants.INT64 - elif input_tensor.dtype == _dtypes.uint8: - tflite_input_type = lite_constants.QUANTIZED_UINT8 - # TODO(aselle): Insert strings when they are available - else: - raise ValueError("Tensors %s not known type %r" % (input_tensor.name, - input_tensor.dtype)) - input_array = model.input_arrays.add() - if inference_type == lite_constants.QUANTIZED_UINT8: - if tflite_input_type == lite_constants.FLOAT: - tflite_input_type = lite_constants.QUANTIZED_UINT8 input_array.mean_value, input_array.std_value = quantized_input_stats[idx] - input_array.name = tensor_name(input_tensor) input_array.shape.dims.extend(map(int, input_tensor.get_shape())) diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py index 779bda4c9d05fd056d6a262412fdcf0d47e7c57c..fd908234254185e0a0639618e936ca8ff58631da 100644 --- a/tensorflow/contrib/lite/python/interpreter.py +++ b/tensorflow/contrib/lite/python/interpreter.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys from tensorflow.python.util.lazy_loader import LazyLoader # Lazy load since some of the performance benchmark skylark rules @@ -64,9 +65,38 @@ class Interpreter(object): raise ValueError('Can\'t both provide `model_path` and `model_content`') def allocate_tensors(self): + self._ensure_safe() if not self._interpreter.AllocateTensors(): raise ValueError('Failed to allocate tensors') + def _safe_to_run(self): + """Returns true if there exist no numpy array buffers. + + This means it is safe to run tflite calls that may destroy internally + allocated memory. This works, because in the wrapper.cc we have made + the numpy base be the self._interpreter. + """ + # NOTE, our tensor() call in cpp will use _interpreter as a base pointer. + # If this environment is the only _interpreter, then the ref count should be + # 2 (1 in self and 1 in temporary of sys.getrefcount). + return sys.getrefcount(self._interpreter) == 2 + + def _ensure_safe(self): + """Makes sure no numpy arrays pointing to internal buffers are active. + + This should be called from any function that will call a function on + _interpreter that may reallocate memory e.g. invoke(), ... + + Raises: + RuntimeError: If there exist numpy objects pointing to internal memory + then we throw. + """ + if not self._safe_to_run(): + raise RuntimeError("""There is at least 1 reference to internal data + in the interpreter in the form of a numpy array or slice. Be sure to + only hold the function returned from tensor() if you are using raw + data access.""") + def _get_tensor_details(self, tensor_index): """Gets tensor details. @@ -109,7 +139,10 @@ class Interpreter(object): ] def set_tensor(self, tensor_index, value): - """Sets the value of the input. + """Sets the value of the input tensor. Note this copies data in `value`. + + If you want to avoid copying, you can use the `tensor()` function to get a + numpy buffer pointing to the input buffer in the tflite interpreter. Args: tensor_index: Tensor index of tensor to set. This value can be gotten from @@ -133,6 +166,7 @@ class Interpreter(object): Raises: ValueError: If the interpreter could not resize the input tensor. """ + self._ensure_safe() if not self._interpreter.ResizeInputTensor(input_index, tensor_size): raise ValueError('Failed to resize input') @@ -147,7 +181,7 @@ class Interpreter(object): ] def get_tensor(self, tensor_index): - """Sets the value of the input. + """Gets the value of the input tensor. Note this makes a copy so prefer `tensor()`. Args: tensor_index: Tensor index of tensor to get. This value can be gotten from @@ -158,6 +192,60 @@ class Interpreter(object): """ return self._interpreter.GetTensor(tensor_index) + def tensor(self, tensor_index): + """Returns function that gives a numpy view of the current tensor buffer. + + This allows reading and writing to this tensors w/o copies. This more + closely mirrors the C++ Interpreter class interface's tensor() member, hence + the name. Be careful to not hold these output references through calls + to `allocate_tensors()` and `invoke()`. + + Usage: + + interpreter.allocate_tensors() + input = interpreter.tensor(interpreter.get_input_details()[0]["index"]) + output = interpreter.tensor(interpreter.get_output_details()[0]["index"]) + for i in range(10): + input().fill(3.) + interpreter.invoke() + print("inference %s" % output) + + Notice how this function avoids making a numpy array directly. This is + because it is important to not hold actual numpy views to the data longer + than necessary. If you do, then the interpreter can no longer be invoked, + because it is possible the interpreter would resize and invalidate the + referenced tensors. The NumPy API doesn't allow any mutability of the + the underlying buffers. + + WRONG: + + input = interpreter.tensor(interpreter.get_input_details()[0]["index"])() + output = interpreter.tensor(interpreter.get_output_details()[0]["index"])() + interpreter.allocate_tensors() # This will throw RuntimeError + for i in range(10): + input.fill(3.) + interpreter.invoke() # this will throw RuntimeError since input,output + + Args: + tensor_index: Tensor index of tensor to get. This value can be gotten from + the 'index' field in get_output_details. + + Returns: + A function that can return a new numpy array pointing to the internal + TFLite tensor state at any point. It is safe to hold the function forever, + but it is not safe to hold the numpy array forever. + """ + return lambda: self._interpreter.tensor(self._interpreter, tensor_index) + def invoke(self): + """Invoke the interpreter. + + Be sure to set the input sizes, allocate tensors and fill values before + calling this. + + Raises: + ValueError: When the underlying interpreter fails raise ValueError. + """ + self._ensure_safe() if not self._interpreter.Invoke(): raise ValueError('Failed to invoke TFLite model') diff --git a/tensorflow/contrib/lite/python/interpreter_test.py b/tensorflow/contrib/lite/python/interpreter_test.py index f802edf020db8a9d4e7bb890aadaae7e34e983a8..5f1fa26c3b7f76309a6f1f80aa3c1e4889781528 100644 --- a/tensorflow/contrib/lite/python/interpreter_test.py +++ b/tensorflow/contrib/lite/python/interpreter_test.py @@ -91,5 +91,61 @@ class InterpreterTest(test_util.TensorFlowTestCase): self.assertTrue((expected_output == output_data).all()) +class InterpreterTensorAccessorTest(test_util.TensorFlowTestCase): + + def setUp(self): + self.interpreter = interpreter_wrapper.Interpreter( + model_path=resource_loader.get_path_to_datafile( + 'testdata/permute_float.tflite')) + self.interpreter.allocate_tensors() + self.input0 = self.interpreter.get_input_details()[0]['index'] + self.initial_data = np.array([[-1., -2., -3., -4.]], np.float32) + + def testTensorAccessor(self): + """Check that tensor returns a reference.""" + array_ref = self.interpreter.tensor(self.input0) + np.copyto(array_ref(), self.initial_data) + self.assertAllEqual(array_ref(), self.initial_data) + self.assertAllEqual( + self.interpreter.get_tensor(self.input0), self.initial_data) + + def testGetTensorAccessor(self): + """Check that get_tensor returns a copy.""" + self.interpreter.set_tensor(self.input0, self.initial_data) + array_initial_copy = self.interpreter.get_tensor(self.input0) + new_value = np.add(1., array_initial_copy) + self.interpreter.set_tensor(self.input0, new_value) + self.assertAllEqual(array_initial_copy, self.initial_data) + self.assertAllEqual(self.interpreter.get_tensor(self.input0), new_value) + + def testBase(self): + self.assertTrue(self.interpreter._safe_to_run()) + _ = self.interpreter.tensor(self.input0) + self.assertTrue(self.interpreter._safe_to_run()) + in0 = self.interpreter.tensor(self.input0)() + self.assertFalse(self.interpreter._safe_to_run()) + in0b = self.interpreter.tensor(self.input0)() + self.assertFalse(self.interpreter._safe_to_run()) + # Now get rid of the buffers so that we can evaluate. + del in0 + del in0b + self.assertTrue(self.interpreter._safe_to_run()) + + def testBaseProtectsFunctions(self): + in0 = self.interpreter.tensor(self.input0)() + # Make sure we get an exception if we try to run an unsafe operation + with self.assertRaisesRegexp( + RuntimeError, 'There is at least 1 reference'): + _ = self.interpreter.allocate_tensors() + # Make sure we get an exception if we try to run an unsafe operation + with self.assertRaisesRegexp( + RuntimeError, 'There is at least 1 reference'): + _ = self.interpreter.invoke() + # Now test that we can run + del in0 # this is our only buffer reference, so now it is safe to change + in0safe = self.interpreter.tensor(self.input0) + _ = self.interpreter.allocate_tensors() + del in0safe # make sure in0Safe is held but lint doesn't complain + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD index 12ab38847dc0f838ae2c6bf80ed80805285e4b8b..634c2a1e1f5005208b4eea5c853a43cccf4d244c 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD @@ -14,7 +14,7 @@ cc_library( "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite/kernels:builtin_ops", "//tensorflow/core:lib", - "//tensorflow/python:numpy_lib", + "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", "@com_google_absl//absl/memory", ], diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc index e5e5c4fb029d6964fa0f26ae632a2b8e912d1cab..5554d08fa08fdc6ddcb042d12f979164a144e337 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -21,7 +21,14 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/python/lib/core/numpy.h" + +// Disallow Numpy 1.7 deprecated symbols. +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +#include + +#include "numpy/arrayobject.h" +#include "numpy/ufuncobject.h" #if PY_MAJOR_VERSION >= 3 #define PY_TO_CPPSTRING PyBytes_AsStringAndSize @@ -35,6 +42,13 @@ namespace tflite { namespace interpreter_wrapper { namespace { + +// Calls PyArray's initialization to initialize all the API pointers. Note that +// this usage implies only this translation unit can use the pointers. See +// tensorflow/python/core/numpy.cc for a strategy if we ever need to extend +// this further. +void ImportNumpy() { import_array1(); } + std::unique_ptr CreateInterpreter( const tflite::FlatBufferModel* model, const tflite::ops::builtin::BuiltinOpResolver& resolver) { @@ -42,7 +56,7 @@ std::unique_ptr CreateInterpreter( return nullptr; } - tensorflow::ImportNumpy(); + ImportNumpy(); std::unique_ptr interpreter; tflite::InterpreterBuilder(*model, resolver)(&interpreter); @@ -78,6 +92,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { return NPY_OBJECT; case kTfLiteBool: return NPY_BOOL; + case kTfLiteComplex64: + return NPY_COMPLEX64; case kTfLiteNoType: return -1; } @@ -104,6 +120,8 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { case NPY_STRING: case NPY_UNICODE: return kTfLiteString; + case NPY_COMPLEX64: + return kTfLiteComplex64; } LOG(ERROR) << "Unknown PyArray dtype " << pyarray_type; return kTfLiteNoType; @@ -288,47 +306,93 @@ bool InterpreterWrapper::SetTensor(int i, PyObject* value) { return true; } -PyObject* InterpreterWrapper::GetTensor(int i) const { - if (!interpreter_) { +namespace { + +PyObject* CheckGetTensorArgs(Interpreter* interpreter, int tensor_index, + TfLiteTensor** tensor, int* type_num) { + if (!interpreter) { LOG(ERROR) << "Invalid interpreter."; Py_INCREF(Py_None); return Py_None; } - if (i >= interpreter_->tensors_size()) { - LOG(ERROR) << "Invalid tensor index: " << i << " exceeds max tensor index " - << interpreter_->inputs().size(); + if (tensor_index >= interpreter->tensors_size() || tensor_index < 0) { + LOG(ERROR) << "Invalid tensor index: " << tensor_index + << " exceeds max tensor index " << interpreter->inputs().size(); Py_INCREF(Py_None); return Py_None; } - const TfLiteTensor* output_tensor = interpreter_->tensor(i); - const int tensor_size = output_tensor->bytes; - if (tensor_size <= 0) { + *tensor = interpreter->tensor(tensor_index); + if ((*tensor)->bytes == 0) { LOG(ERROR) << "Invalid tensor size"; Py_INCREF(Py_None); return Py_None; } - int type_num = TfLiteTypeToPyArrayType(output_tensor->type); - if (type_num == -1) { - LOG(ERROR) << "Unknown tensor type " << output_tensor->type; + *type_num = TfLiteTypeToPyArrayType((*tensor)->type); + if (*type_num == -1) { + LOG(ERROR) << "Unknown tensor type " << (*tensor)->type; + Py_INCREF(Py_None); + return Py_None; + } + + if (!(*tensor)->data.raw) { + LOG(ERROR) << "Tensor data is null."; Py_INCREF(Py_None); return Py_None; } - void* data = malloc(tensor_size); - memcpy(data, output_tensor->data.raw, tensor_size); + return nullptr; +} + +} // namespace - const TfLiteIntArray* output_dims = output_tensor->dims; - std::vector dims(output_dims->data, - output_dims->data + output_dims->size); +PyObject* InterpreterWrapper::GetTensor(int i) const { + // Sanity check accessor + TfLiteTensor* tensor = nullptr; + int type_num = 0; + if (PyObject* pynone_or_nullptr = + CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) { + return pynone_or_nullptr; + } + std::vector dims(tensor->dims->data, + tensor->dims->data + tensor->dims->size); + // Make a buffer copy but we must tell Numpy It owns that data or else + // it will leak. + void* data = malloc(tensor->bytes); + if (!data) { + LOG(ERROR) << "Malloc to copy tensor failed."; + Py_INCREF(Py_None); + return Py_None; + } + memcpy(data, tensor->data.raw, tensor->bytes); PyObject* np_array = PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data); - + PyArray_ENABLEFLAGS(reinterpret_cast(np_array), + NPY_ARRAY_OWNDATA); return PyArray_Return(reinterpret_cast(np_array)); } +PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) { + // Sanity check accessor + TfLiteTensor* tensor = nullptr; + int type_num = 0; + if (PyObject* pynone_or_nullptr = + CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) { + return pynone_or_nullptr; + } + + std::vector dims(tensor->dims->data, + tensor->dims->data + tensor->dims->size); + PyArrayObject* np_array = + reinterpret_cast(PyArray_SimpleNewFromData( + dims.size(), dims.data(), type_num, tensor->data.raw)); + Py_INCREF(base_object); // SetBaseObject steals, so we need to add. + PyArray_SetBaseObject(np_array, base_object); + return PyArray_Return(np_array); +} + InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile( const char* model_path) { std::unique_ptr model = diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h index c02aa3804367f787016ef78fc8557005507f051b..681448be20cfc013a0c4d02a6aa549744b976077 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -20,8 +20,8 @@ limitations under the License. #include // Place `` before to avoid build failures in macOS. -#include #include +#include // We forward declare TFLite classes here to avoid exposing them to SWIG. namespace tflite { @@ -58,6 +58,9 @@ class InterpreterWrapper { PyObject* TensorQuantization(int i) const; bool SetTensor(int i, PyObject* value); PyObject* GetTensor(int i) const; + // Returns a reference to tensor index i as a numpy array. The base_object + // should be the interpreter object providing the memory. + PyObject* tensor(PyObject* base_object, int i); private: InterpreterWrapper(std::unique_ptr model); diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 8315066cd129a137b9159690123ae47bee18c1c8..29a1487c1f468055dde85ef6c2657a50f3d2f32b 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -50,12 +50,14 @@ from tensorflow.contrib.lite.python.interpreter import Interpreter # pylint: di from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import from tensorflow.contrib.lite.python.op_hint import OpHint # pylint: disable=unused-import from tensorflow.core.framework import graph_pb2 as _graph_pb2 +from tensorflow.python import keras as _keras from tensorflow.python.client import session as _session from tensorflow.python.framework import graph_util as tf_graph_util from tensorflow.python.framework.importer import import_graph_def from tensorflow.python.ops.variables import global_variables_initializer from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import tag_constants +# from tensorflow.python.util.all_util import remove_undocumented class TocoConverter(object): @@ -66,11 +68,11 @@ class TocoConverter(object): Attributes: - inference_type: Target data type of arrays in the output file. Currently - must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT) - inference_input_type: Target data type of input arrays. Allows for a - different type for input arrays in the case of quantization. Currently - must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) + inference_type: Target data type of real-number arrays in the output file. + Must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT) + inference_input_type: Target data type of real-number input arrays. Allows + for a different type for input arrays in the case of quantization. + Must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) output_format: Output file format. Currently must be `{TFLITE, GRAPHVIZ_DOT}`. (default TFLITE) quantized_input_stats: Dict of strings representing input tensor names @@ -130,7 +132,7 @@ class TocoConverter(object): Args: - graph_def: TensorFlow GraphDef. + graph_def: Frozen TensorFlow GraphDef. input_tensors: List of input tensors. Type and shape are computed using `foo.get_shape()` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). @@ -176,7 +178,7 @@ class TocoConverter(object): """Creates a TocoConverter class from a file containing a frozen GraphDef. Args: - graph_def_file: Full filepath of file containing TensorFlow GraphDef. + graph_def_file: Full filepath of file containing frozen GraphDef. input_arrays: List of input tensors to freeze graph with. output_arrays: List of output tensors to freeze graph with. input_shapes: Dict of strings representing input tensor names to list of @@ -268,6 +270,48 @@ class TocoConverter(object): return cls( graph_def=result[0], input_tensors=result[1], output_tensors=result[2]) + @classmethod + def from_keras_model_file(cls, + model_file, + input_arrays=None, + input_shapes=None, + output_arrays=None): + """Creates a TocoConverter class from a tf.keras model file. + + Args: + model_file: Full filepath of HDF5 file containing the tf.keras model. + input_arrays: List of input tensors to freeze graph with. Uses input + arrays from SignatureDef when none are provided. (default None) + input_shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). + Automatically determined when input shapes is None (e.g., {"foo" : + None}). (default None) + output_arrays: List of output tensors to freeze graph with. Uses output + arrays from SignatureDef when none are provided. (default None) + + Returns: + TocoConverter class. + """ + _keras.backend.clear_session() + _keras.backend.set_learning_phase(False) + keras_model = _keras.models.load_model(model_file) + sess = _keras.backend.get_session() + + # Get input and output tensors. + if input_arrays: + input_tensors = get_tensors_from_tensor_names(sess.graph, input_arrays) + else: + input_tensors = keras_model.inputs + + if output_arrays: + output_tensors = get_tensors_from_tensor_names(sess.graph, output_arrays) + else: + output_tensors = keras_model.outputs + set_tensor_shapes(input_tensors, input_shapes) + + graph_def = _freeze_graph(sess, output_tensors) + return cls(graph_def, input_tensors, output_tensors) + def convert(self): """Converts a TensorFlow GraphDef based on instance variables. @@ -365,7 +409,7 @@ def _is_frozen_graph(sess): Bool. """ for op in sess.graph.get_operations(): - if op.type.startswith("Variable"): + if op.type.startswith("Variable") or op.type.endswith("VariableOp"): return False return True @@ -390,3 +434,5 @@ def _freeze_graph(sess, output_tensors): output_arrays) else: return sess.graph_def + +# remove_undocumented(__name__) diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index 8c9d2c1651dd2d0b3cd27cf638c04429e3131efb..ca2af5aaed3ee4f4fce5f0d31eaa61df0e11f364 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -19,11 +19,13 @@ from __future__ import division from __future__ import print_function import os +import tempfile import numpy as np from tensorflow.contrib.lite.python import lite from tensorflow.contrib.lite.python import lite_constants from tensorflow.contrib.lite.python.interpreter import Interpreter +from tensorflow.python import keras from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -267,7 +269,8 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertTrue(num_items_graphviz_video > num_items_graphviz) def testInferenceInputType(self): - in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.uint8) + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) out_tensor = in_tensor + in_tensor sess = session.Session() @@ -286,14 +289,13 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertEqual('Placeholder', input_details[0]['name']) self.assertEqual(np.uint8, input_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) - self.assertEqual((0., 0.), input_details[0]['quantization']) + self.assertEqual((1., 0.), input_details[0]['quantization']) output_details = interpreter.get_output_details() self.assertEqual(1, len(output_details)) self.assertEqual('add', output_details[0]['name']) - self.assertEqual(np.uint8, output_details[0]['dtype']) + self.assertEqual(np.float32, output_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) - self.assertEqual((0., 0.), input_details[0]['quantization']) def testDefaultRangesStats(self): in_tensor = array_ops.placeholder( @@ -618,5 +620,279 @@ class FromSavedModelTest(test_util.TensorFlowTestCase): self.assertTrue(tflite_model) +class FromKerasFile(test_util.TensorFlowTestCase): + + def setUp(self): + keras.backend.clear_session() + + def _getSequentialModel(self): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.RepeatVector(3)) + model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(), + metrics=[keras.metrics.categorical_accuracy], + sample_weight_mode='temporal') + x = np.random.random((1, 3)) + y = np.random.random((1, 3, 3)) + model.train_on_batch(x, y) + model.predict(x) + + try: + fd, keras_file = tempfile.mkstemp('.h5') + keras.models.save_model(model, keras_file) + finally: + os.close(fd) + return keras_file + + def testSequentialModel(self): + """Test a Sequential tf.keras model with default inputs.""" + keras_file = self._getSequentialModel() + + converter = lite.TocoConverter.from_keras_model_file(keras_file) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + os.remove(keras_file) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('dense_input', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('time_distributed/Reshape_1', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testSequentialModelInputArray(self): + """Test a Sequential tf.keras model testing input arrays argument.""" + keras_file = self._getSequentialModel() + + # Invalid input array raises error. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_keras_model_file( + keras_file, input_arrays=['invalid-input']) + self.assertEqual("Invalid tensors 'invalid-input' were found.", + str(error.exception)) + + # Valid input array. + converter = lite.TocoConverter.from_keras_model_file( + keras_file, input_arrays=['dense_input']) + tflite_model = converter.convert() + os.remove(keras_file) + self.assertTrue(tflite_model) + + def testSequentialModelInputShape(self): + """Test a Sequential tf.keras model testing input shapes argument.""" + keras_file = self._getSequentialModel() + + # Passing in shape of invalid input array has no impact as long as all input + # arrays have a shape. + converter = lite.TocoConverter.from_keras_model_file( + keras_file, input_shapes={'invalid-input': [2, 3]}) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Passing in shape of valid input array. + converter = lite.TocoConverter.from_keras_model_file( + keras_file, input_shapes={'dense_input': [2, 3]}) + tflite_model = converter.convert() + os.remove(keras_file) + self.assertTrue(tflite_model) + + # Check input shape from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('dense_input', input_details[0]['name']) + self.assertTrue(([2, 3] == input_details[0]['shape']).all()) + + def testSequentialModelOutputArray(self): + """Test a Sequential tf.keras model testing output arrays argument.""" + keras_file = self._getSequentialModel() + + # Invalid output array raises error. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_keras_model_file( + keras_file, output_arrays=['invalid-output']) + self.assertEqual("Invalid tensors 'invalid-output' were found.", + str(error.exception)) + + # Valid output array. + converter = lite.TocoConverter.from_keras_model_file( + keras_file, output_arrays=['time_distributed/Reshape_1']) + tflite_model = converter.convert() + os.remove(keras_file) + self.assertTrue(tflite_model) + + def testFunctionalModel(self): + """Test a Functional tf.keras model with default inputs.""" + inputs = keras.layers.Input(shape=(3,), name='input') + x = keras.layers.Dense(2)(inputs) + output = keras.layers.Dense(3)(x) + + model = keras.models.Model(inputs, output) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(), + metrics=[keras.metrics.categorical_accuracy]) + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + model.train_on_batch(x, y) + + model.predict(x) + fd, keras_file = tempfile.mkstemp('.h5') + keras.models.save_model(model, keras_file) + + # Convert to TFLite model. + converter = lite.TocoConverter.from_keras_model_file(keras_file) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + os.close(fd) + os.remove(keras_file) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('input', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('dense_1/BiasAdd', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testFunctionalModelMultipleInputs(self): + """Test a Functional tf.keras model with multiple inputs and outputs.""" + a = keras.layers.Input(shape=(3,), name='input_a') + b = keras.layers.Input(shape=(3,), name='input_b') + dense = keras.layers.Dense(4, name='dense') + c = dense(a) + d = dense(b) + e = keras.layers.Dropout(0.5, name='dropout')(c) + + model = keras.models.Model([a, b], [d, e]) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(), + metrics=[keras.metrics.mae], + loss_weights=[1., 0.5]) + + input_a_np = np.random.random((10, 3)) + input_b_np = np.random.random((10, 3)) + output_d_np = np.random.random((10, 4)) + output_e_np = np.random.random((10, 4)) + model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np]) + + model.predict([input_a_np, input_b_np], batch_size=5) + fd, keras_file = tempfile.mkstemp('.h5') + keras.models.save_model(model, keras_file) + + # Convert to TFLite model. + converter = lite.TocoConverter.from_keras_model_file(keras_file) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + os.close(fd) + os.remove(keras_file) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(2, len(input_details)) + self.assertEqual('input_a', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + self.assertEqual('input_b', input_details[1]['name']) + self.assertEqual(np.float32, input_details[1]['dtype']) + self.assertTrue(([1, 3] == input_details[1]['shape']).all()) + self.assertEqual((0., 0.), input_details[1]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(2, len(output_details)) + self.assertEqual('dense_1/BiasAdd', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 4] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + self.assertEqual('dropout/Identity', output_details[1]['name']) + self.assertEqual(np.float32, output_details[1]['dtype']) + self.assertTrue(([1, 4] == output_details[1]['shape']).all()) + self.assertEqual((0., 0.), output_details[1]['quantization']) + + def testFunctionalSequentialModel(self): + """Test a Functional tf.keras model containing a Sequential model.""" + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.RepeatVector(3)) + model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) + model = keras.models.Model(model.input, model.output) + + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(), + metrics=[keras.metrics.categorical_accuracy], + sample_weight_mode='temporal') + x = np.random.random((1, 3)) + y = np.random.random((1, 3, 3)) + model.train_on_batch(x, y) + model.predict(x) + + model.predict(x) + fd, keras_file = tempfile.mkstemp('.h5') + keras.models.save_model(model, keras_file) + + # Convert to TFLite model. + converter = lite.TocoConverter.from_keras_model_file(keras_file) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + os.close(fd) + os.remove(keras_file) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('dense_input', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('time_distributed/Reshape_1', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py index f497533bed054d260aefc7b3fe67ae655c7cbcda..9bd1f4f76ee693414a8515a5bd2567001b53e2ea 100644 --- a/tensorflow/contrib/lite/python/tflite_convert.py +++ b/tensorflow/contrib/lite/python/tflite_convert.py @@ -23,19 +23,15 @@ import os import sys from tensorflow.contrib.lite.python import lite +from tensorflow.contrib.lite.python import lite_constants from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2 from tensorflow.python.platform import app -def _parse_array(values): +def _parse_array(values, type_fn=str): if values: - return values.split(",") - - -def _parse_int_array(values): - if values: - return [int(val) for val in values.split(",")] + return [type_fn(val) for val in values.split(",") if val] def _parse_set(values): @@ -57,7 +53,8 @@ def _get_toco_converter(flags): input_shapes = None if flags.input_shapes: input_shapes_list = [ - _parse_int_array(shape) for shape in flags.input_shapes.split(":") + _parse_array(shape, type_fn=int) + for shape in flags.input_shapes.split(":") ] input_shapes = dict(zip(input_arrays, input_shapes_list)) output_arrays = _parse_array(flags.output_arrays) @@ -77,6 +74,9 @@ def _get_toco_converter(flags): converter_kwargs["saved_model_dir"] = flags.saved_model_dir converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set) converter_kwargs["signature_key"] = flags.saved_model_signature_key + elif flags.keras_model_file: + converter_fn = lite.TocoConverter.from_keras_model_file + converter_kwargs["model_file"] = flags.keras_model_file return converter_fn(**converter_kwargs) @@ -103,9 +103,9 @@ def _convert_model(flags): if flags.mean_values and flags.std_dev_values: input_arrays = converter.get_input_arrays() - std_dev_values = _parse_int_array(flags.std_dev_values) - mean_values = _parse_int_array(flags.mean_values) - quant_stats = zip(mean_values, std_dev_values) + std_dev_values = _parse_array(flags.std_dev_values, type_fn=int) + mean_values = _parse_array(flags.mean_values, type_fn=int) + quant_stats = list(zip(mean_values, std_dev_values)) if ((not flags.input_arrays and len(input_arrays) > 1) or (len(input_arrays) != len(quant_stats))): raise ValueError("Mismatching --input_arrays, --std_dev_values, and " @@ -130,6 +130,9 @@ def _convert_model(flags): if flags.allow_custom_ops: converter.allow_custom_ops = flags.allow_custom_ops if flags.quantize_weights: + if flags.inference_type == lite_constants.QUANTIZED_UINT8: + raise ValueError("--quantized_weights is not supported with " + "--inference_type=QUANTIZED_UINT8") converter.quantize_weights = flags.quantize_weights if flags.dump_graphviz_dir: converter.dump_graphviz_dir = flags.dump_graphviz_dir @@ -200,6 +203,9 @@ def _check_flags(flags, unparsed): raise ValueError("--default_ranges_min and --default_ranges_max must be " "used together") + if flags.dump_graphviz_video and not flags.dump_graphviz: + raise ValueError("--dump_graphviz_video must be used with --dump_graphviz") + def run_main(_): """Main in toco_convert.py.""" @@ -219,11 +225,15 @@ def run_main(_): input_file_group.add_argument( "--graph_def_file", type=str, - help="Full filepath of file containing TensorFlow GraphDef.") + help="Full filepath of file containing frozen TensorFlow GraphDef.") input_file_group.add_argument( "--saved_model_dir", type=str, help="Full filepath of directory containing the SavedModel.") + input_file_group.add_argument( + "--keras_model_file", + type=str, + help="Full filepath of HDF5 file containing tf.Keras model.") # Model format flags. parser.add_argument( @@ -235,13 +245,13 @@ def run_main(_): "--inference_type", type=str.upper, choices=["FLOAT", "QUANTIZED_UINT8"], - help="Target data type of arrays in the output file.") + help="Target data type of real-number arrays in the output file.") parser.add_argument( "--inference_input_type", type=str.upper, choices=["FLOAT", "QUANTIZED_UINT8"], - help=("Target data type of input arrays. Allows for a different type for " - "input arrays in the case of quantization.")) + help=("Target data type of real-number input arrays. Allows for a " + "different type for input arrays in the case of quantization.")) # Input and output arrays flags. parser.add_argument( @@ -275,12 +285,12 @@ def run_main(_): "--std_dev_values", type=str, help=("Standard deviation of training data for each input tensor, " - "comma-separated. Used for quantization. (default None)")) + "comma-separated integers. Used for quantization. (default None)")) parser.add_argument( "--mean_values", type=str, - help=("Mean of training data for each input tensor, comma-separated. " - "Used for quantization. (default None)")) + help=("Mean of training data for each input tensor, comma-separated " + "integers. Used for quantization. (default None)")) parser.add_argument( "--default_ranges_min", type=int, diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index c7b955a1659cf65ed0e0233b8b75db60887de34c..15fb8bbdb8f100201750faf706eb45b697319dfb 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -35,6 +35,7 @@ enum TensorType : byte { STRING = 5, BOOL = 6, INT16 = 7, + COMPLEX64 = 8, } // Parameters for converting a quantized tensor back to float. Given a @@ -154,6 +155,11 @@ enum BuiltinOperator : byte { EQUAL = 71, NOT_EQUAL = 72, LOG = 73, + SUM=74, + SQRT = 75, + RSQRT = 76, + SHAPE = 77, + POW = 78, } // Options for the builtin operators. @@ -184,7 +190,7 @@ union BuiltinOptions { BatchToSpaceNDOptions, SpaceToBatchNDOptions, TransposeOptions, - MeanOptions, + ReducerOptions, SubOptions, DivOptions, SqueezeOptions, @@ -212,6 +218,8 @@ union BuiltinOptions { ExpandDimsOptions, EqualOptions, NotEqualOptions, + ShapeOptions, + PowOptions, } enum Padding : byte { SAME, VALID } @@ -289,9 +297,18 @@ table BidirectionalSequenceRNNOptions { fused_activation_function:ActivationFunctionType; } +enum FullyConnectedOptionsWeightsFormat: byte { + DEFAULT = 0, + SHUFFLED4x16INT8 = 1, +} + // An implementation of TensorFlow fully_connected (a.k.a Dense) layer. table FullyConnectedOptions { + // Parameters for FullyConnected version 1 or above. fused_activation_function:ActivationFunctionType; + + // Parameters for FullyConnected version 2 or above. + weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT; } table SoftmaxOptions { @@ -411,7 +428,7 @@ table TransposeOptions { table ExpOptions { } -table MeanOptions { +table ReducerOptions { keep_dims: bool; } @@ -492,6 +509,14 @@ table EqualOptions { table NotEqualOptions { } +table ShapeOptions { + // Optional output type of the operation (int32 or int64). Defaults to int32. + out_type : TensorType; +} + +table PowOptions { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 81d4574da7f5025c4dd246b5fc8fe74b7d8b15ae..fe0ff9a7a5ba0764475f4a7c14cd875b3cdb2aa8 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -127,8 +127,8 @@ struct TransposeOptionsT; struct ExpOptions; struct ExpOptionsT; -struct MeanOptions; -struct MeanOptionsT; +struct ReducerOptions; +struct ReducerOptionsT; struct SqueezeOptions; struct SqueezeOptionsT; @@ -193,6 +193,12 @@ struct EqualOptionsT; struct NotEqualOptions; struct NotEqualOptionsT; +struct ShapeOptions; +struct ShapeOptionsT; + +struct PowOptions; +struct PowOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -217,11 +223,12 @@ enum TensorType { TensorType_STRING = 5, TensorType_BOOL = 6, TensorType_INT16 = 7, + TensorType_COMPLEX64 = 8, TensorType_MIN = TensorType_FLOAT32, - TensorType_MAX = TensorType_INT16 + TensorType_MAX = TensorType_COMPLEX64 }; -inline TensorType (&EnumValuesTensorType())[8] { +inline TensorType (&EnumValuesTensorType())[9] { static TensorType values[] = { TensorType_FLOAT32, TensorType_FLOAT16, @@ -230,7 +237,8 @@ inline TensorType (&EnumValuesTensorType())[8] { TensorType_INT64, TensorType_STRING, TensorType_BOOL, - TensorType_INT16 + TensorType_INT16, + TensorType_COMPLEX64 }; return values; } @@ -245,6 +253,7 @@ inline const char **EnumNamesTensorType() { "STRING", "BOOL", "INT16", + "COMPLEX64", nullptr }; return names; @@ -329,11 +338,16 @@ enum BuiltinOperator { BuiltinOperator_EQUAL = 71, BuiltinOperator_NOT_EQUAL = 72, BuiltinOperator_LOG = 73, + BuiltinOperator_SUM = 74, + BuiltinOperator_SQRT = 75, + BuiltinOperator_RSQRT = 76, + BuiltinOperator_SHAPE = 77, + BuiltinOperator_POW = 78, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_LOG + BuiltinOperator_MAX = BuiltinOperator_POW }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[73] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[78] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -407,7 +421,12 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[73] { BuiltinOperator_EXPAND_DIMS, BuiltinOperator_EQUAL, BuiltinOperator_NOT_EQUAL, - BuiltinOperator_LOG + BuiltinOperator_LOG, + BuiltinOperator_SUM, + BuiltinOperator_SQRT, + BuiltinOperator_RSQRT, + BuiltinOperator_SHAPE, + BuiltinOperator_POW }; return values; } @@ -488,6 +507,11 @@ inline const char **EnumNamesBuiltinOperator() { "EQUAL", "NOT_EQUAL", "LOG", + "SUM", + "SQRT", + "RSQRT", + "SHAPE", + "POW", nullptr }; return names; @@ -526,7 +550,7 @@ enum BuiltinOptions { BuiltinOptions_BatchToSpaceNDOptions = 24, BuiltinOptions_SpaceToBatchNDOptions = 25, BuiltinOptions_TransposeOptions = 26, - BuiltinOptions_MeanOptions = 27, + BuiltinOptions_ReducerOptions = 27, BuiltinOptions_SubOptions = 28, BuiltinOptions_DivOptions = 29, BuiltinOptions_SqueezeOptions = 30, @@ -554,11 +578,13 @@ enum BuiltinOptions { BuiltinOptions_ExpandDimsOptions = 52, BuiltinOptions_EqualOptions = 53, BuiltinOptions_NotEqualOptions = 54, + BuiltinOptions_ShapeOptions = 55, + BuiltinOptions_PowOptions = 56, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_NotEqualOptions + BuiltinOptions_MAX = BuiltinOptions_PowOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[55] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[57] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -587,7 +613,7 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[55] { BuiltinOptions_BatchToSpaceNDOptions, BuiltinOptions_SpaceToBatchNDOptions, BuiltinOptions_TransposeOptions, - BuiltinOptions_MeanOptions, + BuiltinOptions_ReducerOptions, BuiltinOptions_SubOptions, BuiltinOptions_DivOptions, BuiltinOptions_SqueezeOptions, @@ -614,7 +640,9 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[55] { BuiltinOptions_TileOptions, BuiltinOptions_ExpandDimsOptions, BuiltinOptions_EqualOptions, - BuiltinOptions_NotEqualOptions + BuiltinOptions_NotEqualOptions, + BuiltinOptions_ShapeOptions, + BuiltinOptions_PowOptions }; return values; } @@ -648,7 +676,7 @@ inline const char **EnumNamesBuiltinOptions() { "BatchToSpaceNDOptions", "SpaceToBatchNDOptions", "TransposeOptions", - "MeanOptions", + "ReducerOptions", "SubOptions", "DivOptions", "SqueezeOptions", @@ -676,6 +704,8 @@ inline const char **EnumNamesBuiltinOptions() { "ExpandDimsOptions", "EqualOptions", "NotEqualOptions", + "ShapeOptions", + "PowOptions", nullptr }; return names; @@ -794,8 +824,8 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_TransposeOptions; }; -template<> struct BuiltinOptionsTraits { - static const BuiltinOptions enum_value = BuiltinOptions_MeanOptions; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ReducerOptions; }; template<> struct BuiltinOptionsTraits { @@ -906,6 +936,14 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_NotEqualOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ShapeOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_PowOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -1145,13 +1183,13 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_TransposeOptions ? reinterpret_cast(value) : nullptr; } - MeanOptionsT *AsMeanOptions() { - return type == BuiltinOptions_MeanOptions ? - reinterpret_cast(value) : nullptr; + ReducerOptionsT *AsReducerOptions() { + return type == BuiltinOptions_ReducerOptions ? + reinterpret_cast(value) : nullptr; } - const MeanOptionsT *AsMeanOptions() const { - return type == BuiltinOptions_MeanOptions ? - reinterpret_cast(value) : nullptr; + const ReducerOptionsT *AsReducerOptions() const { + return type == BuiltinOptions_ReducerOptions ? + reinterpret_cast(value) : nullptr; } SubOptionsT *AsSubOptions() { return type == BuiltinOptions_SubOptions ? @@ -1369,6 +1407,22 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_NotEqualOptions ? reinterpret_cast(value) : nullptr; } + ShapeOptionsT *AsShapeOptions() { + return type == BuiltinOptions_ShapeOptions ? + reinterpret_cast(value) : nullptr; + } + const ShapeOptionsT *AsShapeOptions() const { + return type == BuiltinOptions_ShapeOptions ? + reinterpret_cast(value) : nullptr; + } + PowOptionsT *AsPowOptions() { + return type == BuiltinOptions_PowOptions ? + reinterpret_cast(value) : nullptr; + } + const PowOptionsT *AsPowOptions() const { + return type == BuiltinOptions_PowOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -1476,6 +1530,35 @@ inline const char *EnumNameLSHProjectionType(LSHProjectionType e) { return EnumNamesLSHProjectionType()[index]; } +enum FullyConnectedOptionsWeightsFormat { + FullyConnectedOptionsWeightsFormat_DEFAULT = 0, + FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8 = 1, + FullyConnectedOptionsWeightsFormat_MIN = FullyConnectedOptionsWeightsFormat_DEFAULT, + FullyConnectedOptionsWeightsFormat_MAX = FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8 +}; + +inline FullyConnectedOptionsWeightsFormat (&EnumValuesFullyConnectedOptionsWeightsFormat())[2] { + static FullyConnectedOptionsWeightsFormat values[] = { + FullyConnectedOptionsWeightsFormat_DEFAULT, + FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8 + }; + return values; +} + +inline const char **EnumNamesFullyConnectedOptionsWeightsFormat() { + static const char *names[] = { + "DEFAULT", + "SHUFFLED4x16INT8", + nullptr + }; + return names; +} + +inline const char *EnumNameFullyConnectedOptionsWeightsFormat(FullyConnectedOptionsWeightsFormat e) { + const size_t index = static_cast(e); + return EnumNamesFullyConnectedOptionsWeightsFormat()[index]; +} + enum LSTMKernelType { LSTMKernelType_FULL = 0, LSTMKernelType_BASIC = 1, @@ -2528,22 +2611,29 @@ flatbuffers::Offset CreateBidirectionalSequence struct FullyConnectedOptionsT : public flatbuffers::NativeTable { typedef FullyConnectedOptions TableType; ActivationFunctionType fused_activation_function; + FullyConnectedOptionsWeightsFormat weights_format; FullyConnectedOptionsT() - : fused_activation_function(ActivationFunctionType_NONE) { + : fused_activation_function(ActivationFunctionType_NONE), + weights_format(FullyConnectedOptionsWeightsFormat_DEFAULT) { } }; struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef FullyConnectedOptionsT NativeTableType; enum { - VT_FUSED_ACTIVATION_FUNCTION = 4 + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_WEIGHTS_FORMAT = 6 }; ActivationFunctionType fused_activation_function() const { return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } + FullyConnectedOptionsWeightsFormat weights_format() const { + return static_cast(GetField(VT_WEIGHTS_FORMAT, 0)); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + VerifyField(verifier, VT_WEIGHTS_FORMAT) && verifier.EndTable(); } FullyConnectedOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -2557,6 +2647,9 @@ struct FullyConnectedOptionsBuilder { void add_fused_activation_function(ActivationFunctionType fused_activation_function) { fbb_.AddElement(FullyConnectedOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } + void add_weights_format(FullyConnectedOptionsWeightsFormat weights_format) { + fbb_.AddElement(FullyConnectedOptions::VT_WEIGHTS_FORMAT, static_cast(weights_format), 0); + } explicit FullyConnectedOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2571,8 +2664,10 @@ struct FullyConnectedOptionsBuilder { inline flatbuffers::Offset CreateFullyConnectedOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE, + FullyConnectedOptionsWeightsFormat weights_format = FullyConnectedOptionsWeightsFormat_DEFAULT) { FullyConnectedOptionsBuilder builder_(_fbb); + builder_.add_weights_format(weights_format); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } @@ -3839,16 +3934,16 @@ inline flatbuffers::Offset CreateExpOptions( flatbuffers::Offset CreateExpOptions(flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -struct MeanOptionsT : public flatbuffers::NativeTable { - typedef MeanOptions TableType; +struct ReducerOptionsT : public flatbuffers::NativeTable { + typedef ReducerOptions TableType; bool keep_dims; - MeanOptionsT() + ReducerOptionsT() : keep_dims(false) { } }; -struct MeanOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef MeanOptionsT NativeTableType; +struct ReducerOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ReducerOptionsT NativeTableType; enum { VT_KEEP_DIMS = 4 }; @@ -3860,38 +3955,38 @@ struct MeanOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyField(verifier, VT_KEEP_DIMS) && verifier.EndTable(); } - MeanOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(MeanOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + ReducerOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ReducerOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; -struct MeanOptionsBuilder { +struct ReducerOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_keep_dims(bool keep_dims) { - fbb_.AddElement(MeanOptions::VT_KEEP_DIMS, static_cast(keep_dims), 0); + fbb_.AddElement(ReducerOptions::VT_KEEP_DIMS, static_cast(keep_dims), 0); } - explicit MeanOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + explicit ReducerOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } - MeanOptionsBuilder &operator=(const MeanOptionsBuilder &); - flatbuffers::Offset Finish() { + ReducerOptionsBuilder &operator=(const ReducerOptionsBuilder &); + flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset(end); + auto o = flatbuffers::Offset(end); return o; } }; -inline flatbuffers::Offset CreateMeanOptions( +inline flatbuffers::Offset CreateReducerOptions( flatbuffers::FlatBufferBuilder &_fbb, bool keep_dims = false) { - MeanOptionsBuilder builder_(_fbb); + ReducerOptionsBuilder builder_(_fbb); builder_.add_keep_dims(keep_dims); return builder_.Finish(); } -flatbuffers::Offset CreateMeanOptions(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateReducerOptions(flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SqueezeOptionsT : public flatbuffers::NativeTable { typedef SqueezeOptions TableType; @@ -4923,6 +5018,100 @@ inline flatbuffers::Offset CreateNotEqualOptions( flatbuffers::Offset CreateNotEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct ShapeOptionsT : public flatbuffers::NativeTable { + typedef ShapeOptions TableType; + TensorType out_type; + ShapeOptionsT() + : out_type(TensorType_FLOAT32) { + } +}; + +struct ShapeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ShapeOptionsT NativeTableType; + enum { + VT_OUT_TYPE = 4 + }; + TensorType out_type() const { + return static_cast(GetField(VT_OUT_TYPE, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_OUT_TYPE) && + verifier.EndTable(); + } + ShapeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ShapeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ShapeOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_out_type(TensorType out_type) { + fbb_.AddElement(ShapeOptions::VT_OUT_TYPE, static_cast(out_type), 0); + } + explicit ShapeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ShapeOptionsBuilder &operator=(const ShapeOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateShapeOptions( + flatbuffers::FlatBufferBuilder &_fbb, + TensorType out_type = TensorType_FLOAT32) { + ShapeOptionsBuilder builder_(_fbb); + builder_.add_out_type(out_type); + return builder_.Finish(); +} + +flatbuffers::Offset CreateShapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct PowOptionsT : public flatbuffers::NativeTable { + typedef PowOptions TableType; + PowOptionsT() { + } +}; + +struct PowOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef PowOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + PowOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(PowOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct PowOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit PowOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + PowOptionsBuilder &operator=(const PowOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreatePowOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + PowOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreatePowOptions(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -5134,8 +5323,8 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const TransposeOptions *builtin_options_as_TransposeOptions() const { return builtin_options_type() == BuiltinOptions_TransposeOptions ? static_cast(builtin_options()) : nullptr; } - const MeanOptions *builtin_options_as_MeanOptions() const { - return builtin_options_type() == BuiltinOptions_MeanOptions ? static_cast(builtin_options()) : nullptr; + const ReducerOptions *builtin_options_as_ReducerOptions() const { + return builtin_options_type() == BuiltinOptions_ReducerOptions ? static_cast(builtin_options()) : nullptr; } const SubOptions *builtin_options_as_SubOptions() const { return builtin_options_type() == BuiltinOptions_SubOptions ? static_cast(builtin_options()) : nullptr; @@ -5218,6 +5407,12 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const NotEqualOptions *builtin_options_as_NotEqualOptions() const { return builtin_options_type() == BuiltinOptions_NotEqualOptions ? static_cast(builtin_options()) : nullptr; } + const ShapeOptions *builtin_options_as_ShapeOptions() const { + return builtin_options_type() == BuiltinOptions_ShapeOptions ? static_cast(builtin_options()) : nullptr; + } + const PowOptions *builtin_options_as_PowOptions() const { + return builtin_options_type() == BuiltinOptions_PowOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -5353,8 +5548,8 @@ template<> inline const TransposeOptions *Operator::builtin_options_as inline const MeanOptions *Operator::builtin_options_as() const { - return builtin_options_as_MeanOptions(); +template<> inline const ReducerOptions *Operator::builtin_options_as() const { + return builtin_options_as_ReducerOptions(); } template<> inline const SubOptions *Operator::builtin_options_as() const { @@ -5465,6 +5660,14 @@ template<> inline const NotEqualOptions *Operator::builtin_options_as inline const ShapeOptions *Operator::builtin_options_as() const { + return builtin_options_as_ShapeOptions(); +} + +template<> inline const PowOptions *Operator::builtin_options_as() const { + return builtin_options_as_PowOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -6244,6 +6447,7 @@ inline void FullyConnectedOptions::UnPackTo(FullyConnectedOptionsT *_o, const fl (void)_o; (void)_resolver; { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; + { auto _e = weights_format(); _o->weights_format = _e; }; } inline flatbuffers::Offset FullyConnectedOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -6255,9 +6459,11 @@ inline flatbuffers::Offset CreateFullyConnectedOptions(fl (void)_o; struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FullyConnectedOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _fused_activation_function = _o->fused_activation_function; + auto _weights_format = _o->weights_format; return tflite::CreateFullyConnectedOptions( _fbb, - _fused_activation_function); + _fused_activation_function, + _weights_format); } inline SoftmaxOptionsT *SoftmaxOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -6864,28 +7070,28 @@ inline flatbuffers::Offset CreateExpOptions(flatbuffers::FlatBufferB _fbb); } -inline MeanOptionsT *MeanOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new MeanOptionsT(); +inline ReducerOptionsT *ReducerOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ReducerOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void MeanOptions::UnPackTo(MeanOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void ReducerOptions::UnPackTo(ReducerOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; { auto _e = keep_dims(); _o->keep_dims = _e; }; } -inline flatbuffers::Offset MeanOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreateMeanOptions(_fbb, _o, _rehasher); +inline flatbuffers::Offset ReducerOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateReducerOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateMeanOptions(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateReducerOptions(flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const MeanOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ReducerOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _keep_dims = _o->keep_dims; - return tflite::CreateMeanOptions( + return tflite::CreateReducerOptions( _fbb, _keep_dims); } @@ -7415,6 +7621,55 @@ inline flatbuffers::Offset CreateNotEqualOptions(flatbuffers::F _fbb); } +inline ShapeOptionsT *ShapeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ShapeOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ShapeOptions::UnPackTo(ShapeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = out_type(); _o->out_type = _e; }; +} + +inline flatbuffers::Offset ShapeOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateShapeOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateShapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ShapeOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _out_type = _o->out_type; + return tflite::CreateShapeOptions( + _fbb, + _out_type); +} + +inline PowOptionsT *PowOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new PowOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void PowOptions::UnPackTo(PowOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset PowOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreatePowOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreatePowOptions(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PowOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreatePowOptions( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -7708,8 +7963,8 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } - case BuiltinOptions_MeanOptions: { - auto ptr = reinterpret_cast(obj); + case BuiltinOptions_ReducerOptions: { + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case BuiltinOptions_SubOptions: { @@ -7820,6 +8075,14 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_ShapeOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_PowOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -7942,8 +8205,8 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } - case BuiltinOptions_MeanOptions: { - auto ptr = reinterpret_cast(obj); + case BuiltinOptions_ReducerOptions: { + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_SubOptions: { @@ -8054,6 +8317,14 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_ShapeOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_PowOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -8164,9 +8435,9 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateTransposeOptions(_fbb, ptr, _rehasher).Union(); } - case BuiltinOptions_MeanOptions: { - auto ptr = reinterpret_cast(value); - return CreateMeanOptions(_fbb, ptr, _rehasher).Union(); + case BuiltinOptions_ReducerOptions: { + auto ptr = reinterpret_cast(value); + return CreateReducerOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_SubOptions: { auto ptr = reinterpret_cast(value); @@ -8276,6 +8547,14 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateNotEqualOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_ShapeOptions: { + auto ptr = reinterpret_cast(value); + return CreateShapeOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_PowOptions: { + auto ptr = reinterpret_cast(value); + return CreatePowOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -8386,8 +8665,8 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new TransposeOptionsT(*reinterpret_cast(u.value)); break; } - case BuiltinOptions_MeanOptions: { - value = new MeanOptionsT(*reinterpret_cast(u.value)); + case BuiltinOptions_ReducerOptions: { + value = new ReducerOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SubOptions: { @@ -8498,6 +8777,14 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new NotEqualOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_ShapeOptions: { + value = new ShapeOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_PowOptions: { + value = new PowOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -8635,8 +8922,8 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } - case BuiltinOptions_MeanOptions: { - auto ptr = reinterpret_cast(value); + case BuiltinOptions_ReducerOptions: { + auto ptr = reinterpret_cast(value); delete ptr; break; } @@ -8775,6 +9062,16 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_ShapeOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_PowOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index f5e25784fa17209af7cfb06d32aeea2b9b947196..1360f1a27383a709accc1abbd723601854d48a12 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -137,7 +137,7 @@ def toco_options(data_types, Returns: the options in a string. """ - shape_str = ":".join([",".join(str(y) for y in x) for x in shapes]) + shape_str = ":".join([",".join(str(y) for y in x) for x in shapes if x]) inference_type = "FLOAT" # TODO(ahentz): if we get multi-input quantization to work we need this # to change @@ -705,7 +705,7 @@ def make_constant_tests(zip_path): def make_binary_op_tests(zip_path, binary_operator): - """Make a set of tests to do add with and without broadcast.""" + """Make a set of tests to do binary ops with and without broadcast.""" # These parameters are split because we don't support broadcasting. test_parameters = [{ @@ -834,6 +834,12 @@ def make_mean_tests(zip_path): return make_reduce_tests(tf.reduce_mean)(zip_path) +def make_sum_tests(zip_path): + """Make a set of tests to do sum.""" + + return make_reduce_tests(tf.reduce_sum)(zip_path) + + def make_exp_tests(zip_path): """Make a set of tests to do exp.""" @@ -984,6 +990,10 @@ def make_mul_tests(zip_path): make_binary_op_tests(zip_path, tf.multiply) +def make_pow_tests(zip_path): + make_binary_op_tests(zip_path, tf.pow) + + def make_gather_tests(zip_path): """Make a set of tests to do gather.""" @@ -1539,6 +1549,32 @@ def make_reshape_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_shape_tests(zip_path): + """Make a set of tests to do shape.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32], + "input_shape": [[], [0], [1, 1, 1, 3], [2, 3, 4, 5], [5, 5], [10]], + "out_type": [tf.int32, tf.int64], + }] + + def build_graph(parameters): + """Build the topk op testing graph.""" + # Note that we intentionally leave out the shape from the input placeholder + # to prevent the Shape operation from being optimized out during conversion. + input_value = tf.placeholder(dtype=parameters["input_dtype"], name="input") + out = tf.shape(input_value, out_type=parameters["out_type"]) + return [input_value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_resize_bilinear_tests(zip_path): """Make a set of tests to do resize_bilinear.""" @@ -2431,7 +2467,7 @@ def _make_elementwise_tests(op): }] def build_graph(parameters): - """Build the sin op testing graph.""" + """Build the unary op testing graph.""" input_value = tf.placeholder( dtype=parameters["input_dtype"], name="input1", @@ -2460,6 +2496,16 @@ def make_log_tests(zip_path): return _make_elementwise_tests(tf.log)(zip_path) +def make_sqrt_tests(zip_path): + """Make a set of tests to do sqrt.""" + return _make_elementwise_tests(tf.sqrt)(zip_path) + + +def make_rsqrt_tests(zip_path): + """Make a set of tests to do 1/sqrt.""" + return _make_elementwise_tests(tf.rsqrt)(zip_path) + + def make_where_tests(zip_path): """Make a set of tests to do where.""" diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index 8a59d756f8dbbcefc930b5285c1ced8ce6b08845..a86cd5c6ccfc980e8b3a83526714fa11dcc3a4a9 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -52,9 +52,6 @@ tensorflow::Env* env = tensorflow::Env::Default(); // Key is a substring of the test name and value is a bug number. // TODO(ahentz): make sure we clean this list up frequently. std::map kBrokenTests = { - // Add only supports float32. (and "constant" tests use Add) - {R"(^\/add_a.*int32)", "68808744"}, - {R"(^\/constant.*int32)", "68808744"}, {R"(^\/mul.*int32)", "68808744"}, {R"(^\/div.*int32)", "68808744"}, {R"(^\/sub.*int32)", "68808744"}, diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc index 54edfdfb1df3f45b4823a36503c01551348ead6c..4d08fb545801521213890a4f5a9b010de57b27cd 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.cc +++ b/tensorflow/contrib/lite/testing/tflite_driver.cc @@ -288,8 +288,8 @@ void TfLiteDriver::ResetLSTMStateTensors() { interpreter_->ResetVariableTensorsToZero(); // Below is a workaround for initializing state tensors for LSTM. - // TODO(ycling): Refactoring and find a better way to initialize state - // tensors. Maybe write the reset instructions into the test data. + // TODO(ycling): Remove the code below after nobody is using the 18-inputs + // definition. for (auto node_index : interpreter_->execution_plan()) { const auto& node_and_reg = interpreter_->node_and_registration(node_index); const auto& node = node_and_reg->first; @@ -299,7 +299,7 @@ void TfLiteDriver::ResetLSTMStateTensors() { const auto* params = reinterpret_cast(node.builtin_data); if (params->kernel_type == kTfLiteLSTMFullKernel && - node.outputs->size >= 2) { + node.inputs->size == 18 && node.outputs->size >= 2) { // The first 2 outputs of LSTM are state tensors. for (int i = 0; i < 2; ++i) { int node_index = node.outputs->data[i]; diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index dd05c484fabf4d87dc12b39940a71677af4023e2..f74fc45330e825a41c0ec9d93033fea60bb4de09 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -143,7 +143,6 @@ cc_library( ":toco_graphviz_dump_options", ":toco_port", ":types_proto_cc", - "//tensorflow/cc/saved_model:tag_constants", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "@com_google_absl//absl/strings", @@ -169,41 +168,6 @@ cc_library( ], ) -cc_library( - name = "toco_saved_model", - srcs = [ - "toco_saved_model.cc", - ], - hdrs = [ - "toco_saved_model.h", - ], - visibility = ["//visibility:public"], - deps = [ - ":model_cmdline_flags", - ":model_flags_proto_cc", - ":toco_flags_proto_cc", - ":types_proto_cc", - "//tensorflow/cc/tools:freeze_saved_model", - "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/strings", - ], -) - -tf_cc_test( - name = "toco_saved_model_test", - srcs = ["toco_saved_model_test.cc"], - deps = [ - ":model_cmdline_flags", - ":toco_cmdline_flags", - ":toco_saved_model", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:scope", - "//tensorflow/core:test", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - ], -) - cc_library( name = "graph_transformations", srcs = [ @@ -221,7 +185,6 @@ cc_library( "graph_transformations/drop_im2col_arrays.cc", "graph_transformations/ensure_bias_vectors.cc", "graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc", - "graph_transformations/experimental_shuffle_fc_weights.cc", "graph_transformations/fuse_activation_functions.cc", "graph_transformations/fuse_binary_into_following_affine.cc", "graph_transformations/fuse_binary_into_preceding_affine.cc", @@ -296,6 +259,7 @@ cc_library( "graph_transformations/resolve_tensorflow_merge.cc", "graph_transformations/resolve_tensorflow_switch.cc", "graph_transformations/resolve_transpose_attributes.cc", + "graph_transformations/shuffle_fc_weights.cc", "graph_transformations/unfuse_activation_functions.cc", "graph_transformations/unpartition_embedding_lookup.cc", "graph_transformations/unroll_batch_matmul.cc", @@ -431,7 +395,6 @@ tf_cc_binary( ":toco_cmdline_flags", ":toco_flags_proto_cc", ":toco_port", - ":toco_saved_model", ":toco_tooling", ":types_proto_cc", "//tensorflow/core:lib", diff --git a/tensorflow/contrib/lite/toco/README.md b/tensorflow/contrib/lite/toco/README.md index 522e260ad2a14c5f8e080c0a0f538f4192b7ed2d..2db6a627ab59604a99cafe3b38df08b70092d989 100644 --- a/tensorflow/contrib/lite/toco/README.md +++ b/tensorflow/contrib/lite/toco/README.md @@ -17,11 +17,12 @@ Usage information is given in these documents: Once an application developer has a trained TensorFlow model, TOCO will accept that model and generate a TensorFlow Lite [FlatBuffer](https://google.github.io/flatbuffers/) file. TOCO currently supports -[SavedModels](https://www.tensorflow.org/programmers_guide/saved_model#using_savedmodel_with_estimators) -and frozen graphs (models generated via -[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)). -The TensorFlow Lite FlatBuffer file can be shipped to client devices, generally -mobile devices, where the TensorFlow Lite interpreter handles them on-device. -This flow is represented in the diagram below. +[SavedModels](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators), +frozen graphs (models generated via +[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)), +and `tf.Keras` model files. The TensorFlow Lite FlatBuffer file can be shipped +to client devices, generally mobile devices, where the TensorFlow Lite +interpreter handles them on-device. This flow is represented in the diagram +below. ![drawing](g3doc/toco_landscape.svg) diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h index 9f5ca66d050f0ead9b8856c77dba8d9bbd182d10..aef35ad490656c09a7d7314aa033bc985b3af661 100644 --- a/tensorflow/contrib/lite/toco/args.h +++ b/tensorflow/contrib/lite/toco/args.h @@ -21,13 +21,13 @@ limitations under the License. #include #include #include +#include "tensorflow/contrib/lite/toco/toco_port.h" #if defined(PLATFORM_GOOGLE) #include "strings/split.h" +#include "strings/strip.h" #endif #include "absl/strings/numbers.h" #include "absl/strings/str_split.h" -#include "tensorflow/cc/saved_model/tag_constants.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/contrib/lite/toco/toco_types.h" namespace toco { @@ -145,8 +145,10 @@ class Arg final { } string outer_member_copy = outer_member; absl::StripAsciiWhitespace(&outer_member); - if (!TryStripPrefixString(outer_member, "{", &outer_member)) return false; - if (!TryStripSuffixString(outer_member, "}", &outer_member)) return false; + if (!strings::TryStripPrefixString(outer_member, "{", &outer_member)) + return false; + if (!strings::TryStripSuffixString(outer_member, "}", &outer_member)) + return false; const std::vector inner_fields_vector = absl::StrSplit(outer_member, ','); @@ -223,7 +225,7 @@ struct ParsedTocoFlags { Arg output_file; Arg input_format = Arg("TENSORFLOW_GRAPHDEF"); Arg output_format = Arg("TFLITE"); - Arg savedmodel_tagset = Arg(tensorflow::kSavedModelTagServe); + Arg savedmodel_tagset; // TODO(aselle): command_line_flags doesn't support doubles Arg default_ranges_min = Arg(0.); Arg default_ranges_max = Arg(0.); diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc index 878bda36ef3900d6d8c509aca40cee834cefe514..6877fb237c0514a972589ac0301647104f5ed7ed 100644 --- a/tensorflow/contrib/lite/toco/dump_graphviz.cc +++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc @@ -227,7 +227,7 @@ NodeProperties GetPropertiesForArray(const Model& model, NodeProperties GetPropertiesForOperator(const Operator& op) { NodeProperties node_properties; - if (op.type == OperatorType::kTensorFlowUnsupported) { + if (op.type == OperatorType::kUnsupported) { node_properties.label = static_cast(op).tensorflow_op; } else { diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 6e5e0d013750c8669f73003fb9ee861bb4aecb2f..6be6b25f9318deb08bd427d5e3166909fae8f3ea 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -145,7 +145,7 @@ void ConvertFloatTensorConst(const string& name, const Shape& input_shape, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); @@ -162,7 +162,7 @@ void ConvertFloatTensorConst(const string& name, const Shape& input_shape, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); @@ -178,7 +178,7 @@ void ConvertFloatTensorConst(const Model& model, const string& name, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); @@ -199,7 +199,7 @@ void ConvertFloatTensorConst(const Model& model, const string& name, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); @@ -222,7 +222,7 @@ void ConvertIntTensorConst(const Model& model, const string& name, } CHECK(model.HasArray(name)); const auto& array = model.GetArray(name); - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -245,7 +245,7 @@ void CreateIntTensorConst(const string& name, const std::vector& data, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -268,7 +268,7 @@ void CreateMatrixShapeTensorConst(const string& name, int rows, int cols, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -286,7 +286,7 @@ void CreateDummyConcatDimTensorConst(const string& name, int dim, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -301,7 +301,7 @@ void CreateReshapeShapeTensorConst(const string& name, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -341,7 +341,7 @@ void ConvertConvOperator(const Model& model, const ConvOperator& src_op, conv_output += "/conv"; } - auto* conv2d_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* conv2d_op = tensorflow_graph->add_node(); conv2d_op->set_op("Conv2D"); conv2d_op->set_name(conv_output); *conv2d_op->add_input() = src_op.inputs[0]; @@ -377,7 +377,7 @@ void ConvertConvOperator(const Model& model, const ConvOperator& src_op, (*conv2d_op->mutable_attr())["padding"].set_s(padding); if (has_bias) { - auto* biasadd_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node(); biasadd_op->set_op("BiasAdd"); biasadd_op->set_name(src_op.outputs[0]); biasadd_op->add_input(conv_output); @@ -409,7 +409,7 @@ void ConvertDepthwiseConvOperator(const Model& model, conv_output += "/conv"; } - auto* dc2d_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* dc2d_op = tensorflow_graph->add_node(); dc2d_op->set_op("DepthwiseConv2dNative"); dc2d_op->set_name(conv_output); *dc2d_op->add_input() = src_op.inputs[0]; @@ -457,7 +457,7 @@ void ConvertDepthwiseConvOperator(const Model& model, (*dc2d_op->mutable_attr())["padding"].set_s(padding); if (has_bias) { - auto* biasadd_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node(); biasadd_op->set_op("BiasAdd"); biasadd_op->set_name(src_op.outputs[0]); biasadd_op->add_input(conv_output); @@ -482,7 +482,7 @@ void ConvertDepthwiseConvOperator(const Model& model, void ConvertTransposeConvOperator(const Model& model, const TransposeConvOperator& src_op, GraphDef* tensorflow_graph) { - auto* conv2d_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* conv2d_op = tensorflow_graph->add_node(); conv2d_op->set_op("Conv2DBackpropInput"); conv2d_op->set_name(src_op.outputs[0]); *conv2d_op->add_input() = src_op.inputs[0]; @@ -514,7 +514,7 @@ void ConvertTransposeConvOperator(const Model& model, void ConvertDepthToSpaceOperator(const Model& model, const DepthToSpaceOperator& src_op, GraphDef* tensorflow_graph) { - auto* op = tensorflow_graph->add_node(); + tensorflow::NodeDef* op = tensorflow_graph->add_node(); op->set_op("DepthToSpace"); op->set_name(src_op.outputs[0]); *op->add_input() = src_op.inputs[0]; @@ -525,7 +525,7 @@ void ConvertDepthToSpaceOperator(const Model& model, void ConvertSpaceToDepthOperator(const Model& model, const SpaceToDepthOperator& src_op, GraphDef* tensorflow_graph) { - auto* op = tensorflow_graph->add_node(); + tensorflow::NodeDef* op = tensorflow_graph->add_node(); op->set_op("SpaceToDepth"); op->set_name(src_op.outputs[0]); *op->add_input() = src_op.inputs[0]; @@ -546,7 +546,7 @@ void ConvertFullyConnectedOperator(const Model& model, CHECK_EQ(fc_weights_shape.dimensions_count(), 2); CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1, tensorflow_graph); - auto* reshape_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node(); reshape_op->set_op("Reshape"); reshape_op->set_name(reshape_output); reshape_op->add_input(src_op.inputs[0]); @@ -568,7 +568,7 @@ void ConvertFullyConnectedOperator(const Model& model, const string transpose_perm = AvailableArrayName(model, transpose_output + "/perm"); CreateIntTensorConst(transpose_perm, {1, 0}, {2}, tensorflow_graph); - auto transpose_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* transpose_op = tensorflow_graph->add_node(); transpose_op->set_op("Transpose"); transpose_op->set_name(transpose_output); *transpose_op->add_input() = src_op.inputs[1]; @@ -577,7 +577,7 @@ void ConvertFullyConnectedOperator(const Model& model, GetTensorFlowDataType(model, src_op.inputs[1])); (*transpose_op->mutable_attr())["Tperm"].set_type(DT_INT32); - auto* matmul_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* matmul_op = tensorflow_graph->add_node(); matmul_op->set_op("MatMul"); matmul_op->set_name(matmul_output); *matmul_op->add_input() = reshape_output; @@ -590,7 +590,7 @@ void ConvertFullyConnectedOperator(const Model& model, // Add the bias, if it exists. if (has_bias) { - auto* biasadd_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node(); biasadd_op->set_op("BiasAdd"); biasadd_op->set_name(src_op.outputs[0]); biasadd_op->add_input(matmul_output); @@ -615,7 +615,7 @@ void ConvertFullyConnectedOperator(const Model& model, void ConvertAddOperator(const Model& model, const AddOperator& src_op, GraphDef* tensorflow_graph) { - auto* add_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* add_op = tensorflow_graph->add_node(); add_op->set_op("Add"); add_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -626,7 +626,7 @@ void ConvertAddOperator(const Model& model, const AddOperator& src_op, void ConvertAddNOperator(const Model& model, const AddNOperator& src_op, GraphDef* tensorflow_graph) { - auto* add_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* add_op = tensorflow_graph->add_node(); add_op->set_op("AddN"); add_op->set_name(src_op.outputs[0]); for (const auto& input : src_op.inputs) { @@ -638,7 +638,7 @@ void ConvertAddNOperator(const Model& model, const AddNOperator& src_op, void ConvertMulOperator(const Model& model, const MulOperator& src_op, GraphDef* tensorflow_graph) { - auto* add_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* add_op = tensorflow_graph->add_node(); add_op->set_op("Mul"); add_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -649,7 +649,7 @@ void ConvertMulOperator(const Model& model, const MulOperator& src_op, void ConvertReluOperator(const ReluOperator& src_op, GraphDef* tensorflow_graph) { - auto* relu_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* relu_op = tensorflow_graph->add_node(); relu_op->set_op("Relu"); relu_op->set_name(src_op.outputs[0]); *relu_op->add_input() = src_op.inputs[0]; @@ -662,7 +662,7 @@ void ConvertRelu1Operator(const Relu1Operator& src_op, const string min_bounds = src_op.outputs[0] + "/min_bounds"; const string max_output = src_op.outputs[0] + "/max_output"; - auto* max_bounds_const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* max_bounds_const_op = tensorflow_graph->add_node(); max_bounds_const_op->set_op("Const"); max_bounds_const_op->set_name(max_bounds); (*max_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); @@ -671,7 +671,7 @@ void ConvertRelu1Operator(const Relu1Operator& src_op, max_bounds_const_op_tensor->set_dtype(DT_FLOAT); max_bounds_const_op_tensor->add_float_val(-1.0f); - auto* min_bounds_const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* min_bounds_const_op = tensorflow_graph->add_node(); min_bounds_const_op->set_op("Const"); min_bounds_const_op->set_name(min_bounds); (*min_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); @@ -680,14 +680,14 @@ void ConvertRelu1Operator(const Relu1Operator& src_op, min_bounds_const_op_tensor->set_dtype(DT_FLOAT); min_bounds_const_op_tensor->add_float_val(1.0f); - auto* max_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* max_op = tensorflow_graph->add_node(); max_op->set_op("Maximum"); max_op->set_name(max_output); *max_op->add_input() = src_op.inputs[0]; *max_op->add_input() = max_bounds; (*max_op->mutable_attr())["T"].set_type(DT_FLOAT); - auto* min_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* min_op = tensorflow_graph->add_node(); min_op->set_op("Minimum"); min_op->set_name(src_op.outputs[0]); *min_op->add_input() = max_output; @@ -697,7 +697,7 @@ void ConvertRelu1Operator(const Relu1Operator& src_op, void ConvertRelu6Operator(const Relu6Operator& src_op, GraphDef* tensorflow_graph) { - auto* relu_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* relu_op = tensorflow_graph->add_node(); relu_op->set_op("Relu6"); relu_op->set_name(src_op.outputs[0]); *relu_op->add_input() = src_op.inputs[0]; @@ -705,7 +705,7 @@ void ConvertRelu6Operator(const Relu6Operator& src_op, } void ConvertLogOperator(const LogOperator& src_op, GraphDef* tensorflow_graph) { - auto* op = tensorflow_graph->add_node(); + tensorflow::NodeDef* op = tensorflow_graph->add_node(); op->set_op("Log"); op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -715,7 +715,7 @@ void ConvertLogOperator(const LogOperator& src_op, GraphDef* tensorflow_graph) { void ConvertLogisticOperator(const LogisticOperator& src_op, GraphDef* tensorflow_graph) { - auto* relu_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* relu_op = tensorflow_graph->add_node(); relu_op->set_op("Sigmoid"); relu_op->set_name(src_op.outputs[0]); *relu_op->add_input() = src_op.inputs[0]; @@ -724,7 +724,7 @@ void ConvertLogisticOperator(const LogisticOperator& src_op, void ConvertTanhOperator(const TanhOperator& src_op, GraphDef* tensorflow_graph) { - auto* tanh_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* tanh_op = tensorflow_graph->add_node(); tanh_op->set_op("Tanh"); tanh_op->set_name(src_op.outputs[0]); *tanh_op->add_input() = src_op.inputs[0]; @@ -735,8 +735,7 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op, GraphDef* tensorflow_graph) { string softmax_input; Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]); - if (providing_op != nullptr && - providing_op->type == OperatorType::kTensorFlowReshape) { + if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) { softmax_input = src_op.inputs[0]; } else { // Insert a reshape operator that reduces the dimensions down to the 2 that @@ -745,7 +744,7 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op, const string softmax_size = src_op.outputs[0] + "/softmax_insert_size"; softmax_input = reshape_output; - auto* reshape_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node(); reshape_op->set_op("Reshape"); reshape_op->set_name(reshape_output); *reshape_op->add_input() = src_op.inputs[0]; @@ -762,7 +761,7 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op, CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph); } - auto* softmax_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* softmax_op = tensorflow_graph->add_node(); softmax_op->set_op("Softmax"); softmax_op->set_name(src_op.outputs[0]); *softmax_op->add_input() = softmax_input; @@ -776,8 +775,7 @@ void ConvertLogSoftmaxOperator(const Model& model, GraphDef* tensorflow_graph) { string softmax_input; Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]); - if (providing_op != nullptr && - providing_op->type == OperatorType::kTensorFlowReshape) { + if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) { softmax_input = src_op.inputs[0]; } else { // Insert a reshape operator that reduces the dimensions down to the 2 that @@ -787,7 +785,7 @@ void ConvertLogSoftmaxOperator(const Model& model, const string softmax_size = src_op.outputs[0] + "/log_softmax_insert_size"; softmax_input = reshape_output; - auto* reshape_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node(); reshape_op->set_op("Reshape"); reshape_op->set_name(reshape_output); *reshape_op->add_input() = src_op.inputs[0]; @@ -804,7 +802,7 @@ void ConvertLogSoftmaxOperator(const Model& model, CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph); } - auto* log_softmax_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* log_softmax_op = tensorflow_graph->add_node(); log_softmax_op->set_op("LogSoftmax"); log_softmax_op->set_name(src_op.outputs[0]); *log_softmax_op->add_input() = softmax_input; @@ -819,7 +817,7 @@ void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op, const string rsqrt_output = src_op.outputs[0] + "/rsqrt"; const string rsqrt_tiled_output = src_op.outputs[0] + "/rsqrt_tiled"; - auto* sum_reduction_indices_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sum_reduction_indices_op = tensorflow_graph->add_node(); sum_reduction_indices_op->set_op("Const"); sum_reduction_indices_op->set_name(sum_reduction_indices); (*sum_reduction_indices_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -833,26 +831,26 @@ void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op, sum_reduction_indices_tensor->add_int_val(0); sum_reduction_indices_tensor->add_int_val(1); - auto* square_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* square_op = tensorflow_graph->add_node(); square_op->set_op("Square"); square_op->set_name(square_output); *square_op->add_input() = src_op.inputs[0]; (*square_op->mutable_attr())["T"].set_type(DT_FLOAT); - auto* sum_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sum_op = tensorflow_graph->add_node(); sum_op->set_op("Sum"); sum_op->set_name(sum_output); *sum_op->add_input() = square_output; *sum_op->add_input() = sum_reduction_indices; (*sum_op->mutable_attr())["T"].set_type(DT_FLOAT); - auto* rsqrt_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* rsqrt_op = tensorflow_graph->add_node(); rsqrt_op->set_op("Rsqrt"); rsqrt_op->set_name(rsqrt_output); *rsqrt_op->add_input() = sum_output; (*rsqrt_op->mutable_attr())["T"].set_type(DT_FLOAT); - auto* mul_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* mul_op = tensorflow_graph->add_node(); mul_op->set_op("Mul"); mul_op->set_name(src_op.outputs[0]); *mul_op->add_input() = src_op.inputs[0]; @@ -863,7 +861,7 @@ void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op, void ConvertLocalResponseNormalizationOperator( const LocalResponseNormalizationOperator& src_op, GraphDef* tensorflow_graph) { - auto* lrn_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* lrn_op = tensorflow_graph->add_node(); lrn_op->set_op("LRN"); lrn_op->set_name(src_op.outputs[0]); *lrn_op->add_input() = src_op.inputs[0]; @@ -875,7 +873,7 @@ void ConvertLocalResponseNormalizationOperator( void ConvertFakeQuantOperator(const FakeQuantOperator& src_op, GraphDef* tensorflow_graph) { - auto* fakequant_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* fakequant_op = tensorflow_graph->add_node(); fakequant_op->set_op("FakeQuantWithMinMaxArgs"); fakequant_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -890,7 +888,7 @@ void ConvertFakeQuantOperator(const FakeQuantOperator& src_op, void ConvertMaxPoolOperator(const MaxPoolOperator& src_op, GraphDef* tensorflow_graph) { - auto* maxpool_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* maxpool_op = tensorflow_graph->add_node(); maxpool_op->set_op("MaxPool"); maxpool_op->set_name(src_op.outputs[0]); *maxpool_op->add_input() = src_op.inputs[0]; @@ -918,7 +916,7 @@ void ConvertMaxPoolOperator(const MaxPoolOperator& src_op, void ConvertAveragePoolOperator(const AveragePoolOperator& src_op, GraphDef* tensorflow_graph) { - auto* avgpool_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* avgpool_op = tensorflow_graph->add_node(); avgpool_op->set_op("AvgPool"); avgpool_op->set_name(src_op.outputs[0]); *avgpool_op->add_input() = src_op.inputs[0]; @@ -947,7 +945,7 @@ void ConvertAveragePoolOperator(const AveragePoolOperator& src_op, void ConvertConcatenationOperator(const Model& model, const ConcatenationOperator& src_op, GraphDef* tensorflow_graph) { - auto* dc_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* dc_op = tensorflow_graph->add_node(); dc_op->set_op("ConcatV2"); dc_op->set_name(src_op.outputs[0]); const string dummy_axis = src_op.outputs[0] + "/axis"; @@ -965,7 +963,7 @@ void ConvertConcatenationOperator(const Model& model, void ConvertTensorFlowReshapeOperator(const Model& model, const TensorFlowReshapeOperator& src_op, GraphDef* tensorflow_graph) { - auto* reshape_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node(); reshape_op->set_op("Reshape"); reshape_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -987,7 +985,7 @@ void ConvertL2PoolOperator(const L2PoolOperator& src_op, const string square_output = src_op.outputs[0] + "/square"; const string avgpool_output = src_op.outputs[0] + "/avgpool"; - auto* square_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* square_op = tensorflow_graph->add_node(); square_op->set_op("Square"); square_op->set_name(square_output); *square_op->add_input() = src_op.inputs[0]; @@ -1002,7 +1000,7 @@ void ConvertL2PoolOperator(const L2PoolOperator& src_op, LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; } - auto* avgpool_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* avgpool_op = tensorflow_graph->add_node(); avgpool_op->set_op("AvgPool"); avgpool_op->set_name(avgpool_output); *avgpool_op->add_input() = square_output; @@ -1020,7 +1018,7 @@ void ConvertL2PoolOperator(const L2PoolOperator& src_op, ksize.mutable_list()->add_i(src_op.kwidth); ksize.mutable_list()->add_i(1); - auto* sqrt_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sqrt_op = tensorflow_graph->add_node(); sqrt_op->set_op("Sqrt"); sqrt_op->set_name(src_op.outputs[0]); *sqrt_op->add_input() = avgpool_output; @@ -1029,7 +1027,7 @@ void ConvertL2PoolOperator(const L2PoolOperator& src_op, void ConvertSquareOperator(const TensorFlowSquareOperator& src_op, GraphDef* tensorflow_graph) { - auto* square_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* square_op = tensorflow_graph->add_node(); square_op->set_op("Square"); square_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -1039,7 +1037,7 @@ void ConvertSquareOperator(const TensorFlowSquareOperator& src_op, void ConvertSqrtOperator(const TensorFlowSqrtOperator& src_op, GraphDef* tensorflow_graph) { - auto* sqrt_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sqrt_op = tensorflow_graph->add_node(); sqrt_op->set_op("Sqrt"); sqrt_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -1047,10 +1045,23 @@ void ConvertSqrtOperator(const TensorFlowSqrtOperator& src_op, (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT); } +void ConvertRsqrtOperator(const Model& model, + const TensorFlowRsqrtOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* rsqrt_op = tensorflow_graph->add_node(); + rsqrt_op->set_op("Rsqrt"); + rsqrt_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *rsqrt_op->add_input() = src_op.inputs[0]; + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*rsqrt_op->mutable_attr())["T"].set_type(data_type); +} + void ConvertSplitOperator(const Model& model, const TensorFlowSplitOperator& src_op, GraphDef* tensorflow_graph) { - auto* split_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* split_op = tensorflow_graph->add_node(); split_op->set_op("Split"); split_op->set_name(src_op.outputs[0]); for (const auto& input : src_op.inputs) { @@ -1071,7 +1082,7 @@ void ConvertSplitOperator(const Model& model, void ConvertCastOperator(const Model& model, const CastOperator& src_op, GraphDef* tensorflow_graph) { - auto* cast_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* cast_op = tensorflow_graph->add_node(); cast_op->set_op("Cast"); cast_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -1085,7 +1096,7 @@ void ConvertCastOperator(const Model& model, const CastOperator& src_op, void ConvertFloorOperator(const Model& model, const FloorOperator& src_op, GraphDef* tensorflow_graph) { - auto* floor_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* floor_op = tensorflow_graph->add_node(); floor_op->set_op("Floor"); floor_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -1095,7 +1106,7 @@ void ConvertFloorOperator(const Model& model, const FloorOperator& src_op, void ConvertGatherOperator(const Model& model, const GatherOperator& src_op, GraphDef* tensorflow_graph) { - auto* gather_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* gather_op = tensorflow_graph->add_node(); gather_op->set_op("Gather"); gather_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1103,13 +1114,14 @@ void ConvertGatherOperator(const Model& model, const GatherOperator& src_op, *gather_op->add_input() = src_op.inputs[1]; (*gather_op->mutable_attr())["Tindices"].set_type(DT_INT32); - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*gather_op->mutable_attr())["Tparams"].set_type(params_type); } void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op, GraphDef* tensorflow_graph) { - auto* argmax_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* argmax_op = tensorflow_graph->add_node(); argmax_op->set_op("ArgMax"); argmax_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1126,7 +1138,7 @@ void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op, void ConvertTransposeOperator(const Model& model, const TransposeOperator& src_op, GraphDef* tensorflow_graph) { - auto* transpose_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* transpose_op = tensorflow_graph->add_node(); transpose_op->set_op("Transpose"); transpose_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1141,7 +1153,7 @@ void ConvertTransposeOperator(const Model& model, void ConvertTensorFlowShapeOperator(const Model& model, const TensorFlowShapeOperator& src_op, GraphDef* tensorflow_graph) { - auto* shape_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* shape_op = tensorflow_graph->add_node(); shape_op->set_op("Shape"); shape_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -1154,7 +1166,7 @@ void ConvertTensorFlowShapeOperator(const Model& model, void ConvertRankOperator(const Model& model, const RankOperator& src_op, GraphDef* tensorflow_graph) { - auto* rank_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* rank_op = tensorflow_graph->add_node(); rank_op->set_op("Rank"); rank_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -1165,7 +1177,7 @@ void ConvertRankOperator(const Model& model, const RankOperator& src_op, void ConvertRangeOperator(const Model& model, const RangeOperator& src_op, GraphDef* tensorflow_graph) { - auto* range_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* range_op = tensorflow_graph->add_node(); range_op->set_op("Range"); range_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 3); @@ -1178,7 +1190,7 @@ void ConvertRangeOperator(const Model& model, const RangeOperator& src_op, void ConvertStackOperator(const Model& model, const StackOperator& src_op, GraphDef* tensorflow_graph) { - auto* stack_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* stack_op = tensorflow_graph->add_node(); stack_op->set_op("Stack"); stack_op->set_name(src_op.outputs[0]); for (const auto& input : src_op.inputs) { @@ -1191,7 +1203,7 @@ void ConvertStackOperator(const Model& model, const StackOperator& src_op, void ConvertFillOperator(const Model& model, const FillOperator& src_op, GraphDef* tensorflow_graph) { - auto* fill_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* fill_op = tensorflow_graph->add_node(); fill_op->set_op("Fill"); fill_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1205,7 +1217,7 @@ void ConvertFillOperator(const Model& model, const FillOperator& src_op, void ConvertFloorDivOperator(const Model& model, const FloorDivOperator& src_op, GraphDef* tensorflow_graph) { - auto* floor_div_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* floor_div_op = tensorflow_graph->add_node(); floor_div_op->set_op("FloorDiv"); floor_div_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1218,7 +1230,7 @@ void ConvertFloorDivOperator(const Model& model, const FloorDivOperator& src_op, void ConvertExpandDimsOperator(const Model& model, const ExpandDimsOperator& src_op, GraphDef* tensorflow_graph) { - auto* expand_dims_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* expand_dims_op = tensorflow_graph->add_node(); expand_dims_op->set_op("ExpandDims"); expand_dims_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1233,7 +1245,7 @@ void ConvertExpandDimsOperator(const Model& model, void ConvertResizeBilinearOperator(const Model& model, const ResizeBilinearOperator& src_op, GraphDef* tensorflow_graph) { - auto* resize_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* resize_op = tensorflow_graph->add_node(); resize_op->set_op("ResizeBilinear"); resize_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1283,7 +1295,7 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // works the same since the tensor has the same underlying data layout. const string axis_output = concat_output + "/axis"; CreateDummyConcatDimTensorConst(axis_output, axis, tensorflow_graph); - auto* concat_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* concat_op = tensorflow_graph->add_node(); concat_op->set_op("ConcatV2"); concat_op->set_name(concat_output); *concat_op->add_input() = src_op.inputs[LstmCellOperator::DATA_INPUT]; @@ -1311,7 +1323,7 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // Fully connected matrix multiply const string matmul_output = base + "MatMul"; - auto* matmul_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* matmul_op = tensorflow_graph->add_node(); matmul_op->set_op("MatMul"); matmul_op->set_name(matmul_output); *matmul_op->add_input() = concat_output; @@ -1340,7 +1352,7 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // Add biases string biasadd_output = base + "BiasAdd"; - auto* biasadd_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node(); biasadd_op->set_op("BiasAdd"); biasadd_op->set_name(biasadd_output); biasadd_op->add_input(matmul_output); @@ -1353,7 +1365,7 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // The dimension is the same as the concatenation dimension CreateDummyConcatDimTensorConst(split_dim_output, axis, tensorflow_graph); string split_output = base + "split"; - auto* split_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* split_op = tensorflow_graph->add_node(); split_op->set_op("Split"); split_op->set_name(split_output); *split_op->add_input() = split_dim_output; @@ -1363,21 +1375,21 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // Activation functions and memory computations const string tanh_0_output = base + "Tanh"; - auto* tanh_0_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* tanh_0_op = tensorflow_graph->add_node(); tanh_0_op->set_op("Tanh"); tanh_0_op->set_name(tanh_0_output); *tanh_0_op->add_input() = split_output + ":1"; (*tanh_0_op->mutable_attr())["T"].set_type(DT_FLOAT); const string sigmoid_1_output = base + "Sigmoid_1"; - auto* logistic_1_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* logistic_1_op = tensorflow_graph->add_node(); logistic_1_op->set_op("Sigmoid"); logistic_1_op->set_name(sigmoid_1_output); *logistic_1_op->add_input() = split_output; (*logistic_1_op->mutable_attr())["T"].set_type(DT_FLOAT); const string mul_1_output = base + "mul_1"; - auto* mul_1_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* mul_1_op = tensorflow_graph->add_node(); mul_1_op->set_op("Mul"); mul_1_op->set_name(mul_1_output); *mul_1_op->add_input() = sigmoid_1_output; @@ -1385,21 +1397,21 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, (*mul_1_op->mutable_attr())["T"].set_type(DT_FLOAT); const string sigmoid_0_output = base + "Sigmoid"; - auto* logistic_2_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* logistic_2_op = tensorflow_graph->add_node(); logistic_2_op->set_op("Sigmoid"); logistic_2_op->set_name(sigmoid_0_output); *logistic_2_op->add_input() = split_output + ":2"; (*logistic_2_op->mutable_attr())["T"].set_type(DT_FLOAT); const string sigmoid_2_output = base + "Sigmoid_2"; - auto* logistic_3_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* logistic_3_op = tensorflow_graph->add_node(); logistic_3_op->set_op("Sigmoid"); logistic_3_op->set_name(sigmoid_2_output); *logistic_3_op->add_input() = split_output + ":3"; (*logistic_3_op->mutable_attr())["T"].set_type(DT_FLOAT); const string mul_0_output = base + "mul"; - auto* mul_0_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* mul_0_op = tensorflow_graph->add_node(); mul_0_op->set_op("Mul"); mul_0_op->set_name(mul_0_output); *mul_0_op->add_input() = src_op.inputs[LstmCellOperator::PREV_STATE_INPUT]; @@ -1407,7 +1419,7 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, (*mul_0_op->mutable_attr())["T"].set_type(DT_FLOAT); const string add_1_output = src_op.outputs[LstmCellOperator::STATE_OUTPUT]; - auto* add_1_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* add_1_op = tensorflow_graph->add_node(); add_1_op->set_op("Add"); add_1_op->set_name(add_1_output); *add_1_op->add_input() = mul_0_output; @@ -1415,14 +1427,14 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, (*add_1_op->mutable_attr())["T"].set_type(DT_FLOAT); const string tanh_1_output = base + "Tanh_1"; - auto* tanh_1_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* tanh_1_op = tensorflow_graph->add_node(); tanh_1_op->set_op("Tanh"); tanh_1_op->set_name(tanh_1_output); *tanh_1_op->add_input() = add_1_output; (*tanh_1_op->mutable_attr())["T"].set_type(DT_FLOAT); const string mul_2_output = src_op.outputs[LstmCellOperator::ACTIV_OUTPUT]; - auto* mul_2_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* mul_2_op = tensorflow_graph->add_node(); mul_2_op->set_op("Mul"); mul_2_op->set_name(mul_2_output); *mul_2_op->add_input() = tanh_1_output; @@ -1433,14 +1445,15 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, void ConvertSpaceToBatchNDOperator(const Model& model, const SpaceToBatchNDOperator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("SpaceToBatchND"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 3); *new_op->add_input() = src_op.inputs[0]; *new_op->add_input() = src_op.inputs[1]; *new_op->add_input() = src_op.inputs[2]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32); (*new_op->mutable_attr())["Tpaddings"].set_type(DT_INT32); @@ -1449,14 +1462,15 @@ void ConvertSpaceToBatchNDOperator(const Model& model, void ConvertBatchToSpaceNDOperator(const Model& model, const BatchToSpaceNDOperator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("BatchToSpaceND"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 3); *new_op->add_input() = src_op.inputs[0]; *new_op->add_input() = src_op.inputs[1]; *new_op->add_input() = src_op.inputs[2]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32); (*new_op->mutable_attr())["Tcrops"].set_type(DT_INT32); @@ -1464,18 +1478,19 @@ void ConvertBatchToSpaceNDOperator(const Model& model, void ConvertPadOperator(const Model& model, const PadOperator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("Pad"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *new_op->add_input() = src_op.inputs[0]; *new_op->add_input() = src_op.inputs[1]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); // Create the params tensor. - auto* params_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* params_op = tensorflow_graph->add_node(); params_op->set_op("Const"); params_op->set_name(src_op.inputs[1]); (*params_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -1494,7 +1509,7 @@ void ConvertPadOperator(const Model& model, const PadOperator& src_op, void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("PadV2"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1502,11 +1517,12 @@ void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op, *new_op->add_input() = src_op.inputs[1]; *new_op->add_input() = src_op.inputs[2]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); // Create the params tensor. - auto* params_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* params_op = tensorflow_graph->add_node(); params_op->set_op("Const"); params_op->set_name(src_op.inputs[1]); (*params_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -1525,7 +1541,7 @@ void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op, void CreateSliceInput(const string& input_name, const std::vector& values, GraphDef* tensorflow_graph) { - auto* params_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* params_op = tensorflow_graph->add_node(); params_op->set_op("Const"); params_op->set_name(input_name); (*params_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -1542,7 +1558,7 @@ void CreateSliceInput(const string& input_name, const std::vector& values, void ConvertStridedSliceOperator(const Model& model, const StridedSliceOperator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("StridedSlice"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 4); @@ -1551,7 +1567,8 @@ void ConvertStridedSliceOperator(const Model& model, *new_op->add_input() = src_op.inputs[2]; *new_op->add_input() = src_op.inputs[3]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); (*new_op->mutable_attr())["Index"].set_type(DT_INT32); @@ -1569,7 +1586,7 @@ void ConvertStridedSliceOperator(const Model& model, void ConvertSliceOperator(const Model& model, const SliceOperator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("Slice"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 3); @@ -1577,7 +1594,8 @@ void ConvertSliceOperator(const Model& model, const SliceOperator& src_op, *new_op->add_input() = src_op.inputs[1]; *new_op->add_input() = src_op.inputs[2]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); (*new_op->mutable_attr())["Index"].set_type(DT_INT32); @@ -1588,14 +1606,15 @@ void ConvertSliceOperator(const Model& model, const SliceOperator& src_op, void ConvertMeanOperator(const Model& model, const MeanOperator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("Mean"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *new_op->add_input() = src_op.inputs[0]; *new_op->add_input() = src_op.inputs[1]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); if (src_op.keep_dims) { @@ -1603,7 +1622,7 @@ void ConvertMeanOperator(const Model& model, const MeanOperator& src_op, } // Create the params tensor. - auto* params_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* params_op = tensorflow_graph->add_node(); params_op->set_op("Const"); params_op->set_name(src_op.inputs[1]); (*params_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -1619,13 +1638,14 @@ void ConvertMeanOperator(const Model& model, const MeanOperator& src_op, void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("Squeeze"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); *new_op->add_input() = src_op.inputs[0]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); if (!src_op.squeeze_dims.empty()) { @@ -1638,74 +1658,79 @@ void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op, void ConvertSubOperator(const Model& model, const SubOperator& src_op, GraphDef* tensorflow_graph) { - auto* sub_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sub_op = tensorflow_graph->add_node(); sub_op->set_op("Sub"); sub_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *sub_op->add_input() = src_op.inputs[0]; *sub_op->add_input() = src_op.inputs[1]; - const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*sub_op->mutable_attr())["T"].set_type(data_type); } void ConvertTensorFlowMinimumOperator(const Model& model, const TensorFlowMinimumOperator& src_op, GraphDef* tensorflow_graph) { - auto* sub_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sub_op = tensorflow_graph->add_node(); sub_op->set_op("Minimum"); sub_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *sub_op->add_input() = src_op.inputs[0]; *sub_op->add_input() = src_op.inputs[1]; - const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*sub_op->mutable_attr())["T"].set_type(data_type); } void ConvertTensorFlowMaximumOperator(const Model& model, const TensorFlowMaximumOperator& src_op, GraphDef* tensorflow_graph) { - auto* sub_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sub_op = tensorflow_graph->add_node(); sub_op->set_op("Maximum"); sub_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *sub_op->add_input() = src_op.inputs[0]; *sub_op->add_input() = src_op.inputs[1]; - const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*sub_op->mutable_attr())["T"].set_type(data_type); } void ConvertSelectOperator(const Model& model, const SelectOperator& src_op, GraphDef* tensorflow_graph) { - auto* sub_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sub_op = tensorflow_graph->add_node(); sub_op->set_op("Select"); sub_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 3); *sub_op->add_input() = src_op.inputs[0]; *sub_op->add_input() = src_op.inputs[1]; *sub_op->add_input() = src_op.inputs[2]; - const auto data_type = GetTensorFlowDataType(model, src_op.inputs[1]); + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[1]); (*sub_op->mutable_attr())["T"].set_type(data_type); } void ConvertTileOperator(const Model& model, const TensorFlowTileOperator& src_op, GraphDef* tensorflow_graph) { - auto* tile_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* tile_op = tensorflow_graph->add_node(); tile_op->set_op("Tile"); tile_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *tile_op->add_input() = src_op.inputs[0]; *tile_op->add_input() = src_op.inputs[1]; - const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*tile_op->mutable_attr())["T"].set_type(data_type); - const auto multiples_data_type = + const tensorflow::DataType multiples_data_type = GetTensorFlowDataType(model, src_op.inputs[1]); (*tile_op->mutable_attr())["Tmultiples"].set_type(multiples_data_type); } void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op, GraphDef* tensorflow_graph) { - auto* topk_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* topk_op = tensorflow_graph->add_node(); topk_op->set_op("TOPKV2"); topk_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1718,12 +1743,13 @@ void ConvertRandomUniformOperator(const Model& model, const RandomUniformOperator& src_op, GraphDef* tensorflow_graph) { CHECK(tensorflow_graph != nullptr); - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("RandomUniform"); CHECK_EQ(src_op.inputs.size(), 1); new_op->set_name(src_op.outputs[0]); *new_op->add_input() = src_op.inputs[0]; - const auto shape_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType shape_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(shape_type); (*new_op->mutable_attr())["dtype"].set_type( GetTensorFlowDataType(src_op.dtype)); @@ -1734,13 +1760,14 @@ void ConvertRandomUniformOperator(const Model& model, void ConvertComparisonOperator(const Model& model, const Operator& src_op, const char* op_name, GraphDef* tensorflow_graph) { - auto* comparison_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* comparison_op = tensorflow_graph->add_node(); comparison_op->set_op(op_name); comparison_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *comparison_op->add_input() = src_op.inputs[0]; *comparison_op->add_input() = src_op.inputs[1]; - const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*comparison_op->mutable_attr())["T"].set_type(data_type); } @@ -1748,21 +1775,37 @@ void ConvertSparseToDenseOperator(const Model& model, const SparseToDenseOperator& src_op, const char* op_name, GraphDef* tensorflow_graph) { - auto* sparse_to_dense_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sparse_to_dense_op = tensorflow_graph->add_node(); sparse_to_dense_op->set_op(op_name); sparse_to_dense_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 4); for (int i = 0; i < 4; ++i) { *sparse_to_dense_op->add_input() = src_op.inputs[i]; } - const auto data_type = GetTensorFlowDataType(model, src_op.inputs[3]); + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[3]); (*sparse_to_dense_op->mutable_attr())["T"].set_type(data_type); - const auto index_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType index_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*sparse_to_dense_op->mutable_attr())["Tindices"].set_type(index_type); (*sparse_to_dense_op->mutable_attr())["Tindices"].set_b( src_op.validate_indices); } +void ConvertPowOperator(const Model& model, const PowOperator& src_op, + const char* op_name, GraphDef* tensorflow_graph) { + tensorflow::NodeDef* pow_op = tensorflow_graph->add_node(); + pow_op->set_op(op_name); + pow_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + for (int i = 0; i < 2; ++i) { + *pow_op->add_input() = src_op.inputs[i]; + } + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*pow_op->mutable_attr())["T"].set_type(data_type); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -1843,20 +1886,24 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertConcatenationOperator( model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowReshape) { + } else if (src_op.type == OperatorType::kReshape) { ConvertTensorFlowReshapeOperator( model, static_cast(src_op), tensorflow_graph); } else if (src_op.type == OperatorType::kL2Pool) { ConvertL2PoolOperator(static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowSquare) { + } else if (src_op.type == OperatorType::kSquare) { ConvertSquareOperator(static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowSqrt) { + } else if (src_op.type == OperatorType::kSqrt) { ConvertSqrtOperator(static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowSplit) { + } else if (src_op.type == OperatorType::kRsqrt) { + ConvertRsqrtOperator(model, + static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kSplit) { ConvertSplitOperator(model, static_cast(src_op), tensorflow_graph); @@ -1900,11 +1947,11 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kSub) { ConvertSubOperator(model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowMinimum) { + } else if (src_op.type == OperatorType::kMinimum) { ConvertTensorFlowMinimumOperator( model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowMaximum) { + } else if (src_op.type == OperatorType::kMaximum) { ConvertTensorFlowMaximumOperator( model, static_cast(src_op), tensorflow_graph); @@ -1923,7 +1970,7 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kTranspose) { ConvertTransposeOperator( model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowShape) { + } else if (src_op.type == OperatorType::kShape) { ConvertTensorFlowShapeOperator( model, static_cast(src_op), tensorflow_graph); @@ -1954,25 +2001,28 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertRandomUniformOperator( model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowEqual) { + } else if (src_op.type == OperatorType::kEqual) { ConvertComparisonOperator(model, src_op, "Equal", tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowNotEqual) { + } else if (src_op.type == OperatorType::kNotEqual) { ConvertComparisonOperator(model, src_op, "NotEqual", tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowGreater) { + } else if (src_op.type == OperatorType::kGreater) { ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowGreaterEqual) { + } else if (src_op.type == OperatorType::kGreaterEqual) { ConvertComparisonOperator(model, src_op, "GreaterEqual", tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowLess) { + } else if (src_op.type == OperatorType::kLess) { ConvertComparisonOperator(model, src_op, "Less", tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowLessEqual) { + } else if (src_op.type == OperatorType::kLessEqual) { ConvertComparisonOperator(model, src_op, "LessEqual", tensorflow_graph); } else if (src_op.type == OperatorType::kSelect) { ConvertSelectOperator(model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowTile) { + } else if (src_op.type == OperatorType::kTile) { ConvertTileOperator(model, static_cast(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kPow) { + ConvertPowOperator(model, static_cast(src_op), "Pow", + tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } @@ -1980,7 +2030,7 @@ void ConvertOperator(const Model& model, const Operator& src_op, void AddPlaceholder(const string& name, ArrayDataType type, GraphDef* tensorflow_graph) { - auto* placeholder = tensorflow_graph->add_node(); + tensorflow::NodeDef* placeholder = tensorflow_graph->add_node(); placeholder->set_op("Placeholder"); switch (type) { case ArrayDataType::kBool: @@ -2009,7 +2059,7 @@ void AddPlaceholder(const string& name, ArrayDataType type, void AddPlaceholderForRNNState(const Model& model, const string& name, int size, GraphDef* tensorflow_graph) { - auto* placeholder = tensorflow_graph->add_node(); + tensorflow::NodeDef* placeholder = tensorflow_graph->add_node(); placeholder->set_op("Placeholder"); placeholder->set_name(name); (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT); diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md index 8e93f02ef109f7bccd07ce54baff3d0bb4ae50c7..18b7848db86e553ec645fa87298420012b5f753f 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md +++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md @@ -9,57 +9,56 @@ complemented by the following documents: Table of contents: -* [Convert a TensorFlow SavedModel to TensorFlow Lite](#savedmodel) -* [Convert a TensorFlow GraphDef to TensorFlow Lite for float - inference](#graphdef-float) +* [Command-line tools](#tools) + * [Converting models prior to TensorFlow 1.9.](#pre-tensorflow-1.9) +* [Basic examples](#basic) + * [Convert a TensorFlow GraphDef](#graphdef) + * [Convert a TensorFlow SavedModel](#savedmodel) + * [Convert a tf.keras model](#keras) * [Quantization](#quantization) - * [Convert a TensorFlow GraphDef to TensorFlow Lite for quantized - inference](#graphdef-quant) + * [Convert a TensorFlow GraphDef for quantized inference](#graphdef-quant) * [Use "dummy-quantization" to try out quantized inference on a float graph](#dummy-quant) * [Specifying input and output arrays](#specifying-input-and-output-arrays) - * [Multiple output arrays](#multiple-output-arrays) * [Multiple input arrays](#multiple-input-arrays) + * [Multiple output arrays](#multiple-output-arrays) * [Specifying subgraphs](#specifying-subgraphs) -* [Other conversions supported by TOCO](#other-conversions) - * [Optimize a TensorFlow GraphDef](#optimize-graphdef) - * [Convert a TensorFlow Lite FlatBuffer back into TensorFlow GraphDef - format](#to-graphdef) -* [Logging](#logging) - * [Graph "video" logging](#graph-video-logging) * [Graph visualizations](#graph-visualizations) * [Using --output_format=GRAPHVIZ_DOT](#using-output-formatgraphviz-dot) * [Using --dump_graphviz](#using-dump-graphviz) + * [Graph "video" logging](#graph-video-logging) * [Legend for the graph visualizations](#graphviz-legend) -## Convert a TensorFlow SavedModel to TensorFlow Lite +## Command-line tools -The follow example converts a basic TensorFlow SavedModel into a Tensorflow Lite -FlatBuffer to perform floating-point inference. +There are two approaches to running TOCO via command line. -``` -bazel run --config=opt \ - third_party/tensorflow/contrib/lite/toco:toco -- \ - --savedmodel_directory=/tmp/saved_model \ - --output_file=/tmp/foo.tflite -``` +* `tflite_convert`: Starting from TensorFlow 1.9, the command-line tool + `tflite_convert` will be installed as part of the Python package. All of the + examples below use `tflite_convert` for simplicity. + * Example: `tflite --output_file=...` +* `bazel`: In order to run the latest version of TOCO, [clone the TensorFlow + repository](https://www.tensorflow.org/install/install_sources#clone_the_tensorflow_repository) + and use `bazel`. This is the recommended approach for converting models that + utilize new features that were not supported by TOCO in TensorFlow 1.9. + * Example: `bazel run + //tensorflow/contrib/lite/python:tflite_convert -- + --output_file=...` -[SavedModel](https://www.tensorflow.org/programmers_guide/saved_model#using_savedmodel_with_estimators) -has fewer required flags than frozen graphs (described [below](#graphdef-float)) -due to access to additional data contained within the SavedModel. The values for -`--input_arrays` and `--output_arrays` are an aggregated, alphabetized list of -the inputs and outputs in the -[SignatureDefs](https://www.tensorflow.org/serving/signature_defs) within the -[MetaGraphDef](https://www.tensorflow.org/programmers_guide/saved_model#apis_to_build_and_load_a_savedmodel) -specified by `--savedmodel_tagset`. The value for `input_shapes` is -automatically determined from the MetaGraphDef whenever possible. The default -value for `--inference_type` for SavedModels is `FLOAT`. +### Converting models prior to TensorFlow 1.9. -There is currently no support for MetaGraphDefs without a SignatureDef or for -MetaGraphDefs that use the [`assets/` -directory](https://www.tensorflow.org/programmers_guide/saved_model#structure_of_a_savedmodel_directory). +The recommended approach for using TOCO prior to TensorFlow 1.9 is the [Python +API](python_api.md#pre-tensorflow-1.9). If a command line tool is desired, the +`toco` command line tool was available in TensorFlow 1.7. Enter `toco --help` in +Terminal for additional details on the command-line flags available. There were +no command line tools in TensorFlow 1.8. + +## Basic examples + +The following section shows examples of how to convert a basic float-point model +from each of the supported data formats into a TensorFlow Lite FlatBuffers. -## Convert a TensorFlow GraphDef to TensorFlow Lite for float inference +### Convert a TensorFlow GraphDef The follow example converts a basic TensorFlow GraphDef (frozen by [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)) @@ -69,19 +68,54 @@ graphs contain the variables stored in Checkpoint files as Const ops. ``` curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ +tflite_convert \ --output_file=/tmp/foo.tflite \ - --inference_type=FLOAT \ - --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 + --graph_def_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ + --input_arrays=input \ + --output_arrays=MobilenetV1/Predictions/Reshape_1 +``` + +The value for `input_shapes` is automatically determined whenever possible. + +### Convert a TensorFlow SavedModel + +The follow example converts a basic TensorFlow SavedModel into a Tensorflow Lite +FlatBuffer to perform floating-point inference. + +``` +tflite_convert \ + --output_file=/tmp/foo.tflite \ + --saved_model_dir=/tmp/saved_model +``` + +[SavedModel](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators) +has fewer required flags than frozen graphs due to access to additional data +contained within the SavedModel. The values for `--input_arrays` and +`--output_arrays` are an aggregated, alphabetized list of the inputs and outputs +in the [SignatureDefs](https://www.tensorflow.org/serving/signature_defs) within +the +[MetaGraphDef](https://www.tensorflow.org/guide/saved_model#apis_to_build_and_load_a_savedmodel) +specified by `--saved_model_tag_set`. As with the GraphDef, the value for +`input_shapes` is automatically determined whenever possible. + +There is currently no support for MetaGraphDefs without a SignatureDef or for +MetaGraphDefs that use the [`assets/` +directory](https://www.tensorflow.org/guide/saved_model#structure_of_a_savedmodel_directory). + +### Convert a tf.Keras model + +The following example converts a `tf.keras` model into a TensorFlow Lite +Flatbuffer. The `tf.keras` file must contain both the model and the weights. + +``` +tflite_convert \ + --output_file=/tmp/foo.tflite \ + --keras_model_file=/tmp/keras_model.h5 ``` ## Quantization -### Convert a TensorFlow GraphDef to TensorFlow Lite for quantized inference +### Convert a TensorFlow GraphDef for quantized inference TOCO is compatible with fixed point quantization models described [here](https://www.tensorflow.org/performance/quantization). These are float @@ -95,18 +129,14 @@ The following command generates a quantized TensorFlow Lite FlatBuffer from a "quantized" TensorFlow GraphDef. ``` -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/some_quantized_graph.pb \ +tflite_convert \ --output_file=/tmp/foo.tflite \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TFLITE \ + --graph_def_file=/tmp/some_quantized_graph.pb \ --inference_type=QUANTIZED_UINT8 \ - --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 \ - --mean_value=128 \ - --std_value=127 + --input_arrays=input \ + --output_arrays=MobilenetV1/Predictions/Reshape_1 \ + --mean_values=128 \ + --std_dev_values=127 ``` ### Use \"dummy-quantization\" to try out quantized inference on a float graph @@ -124,45 +154,20 @@ a reasonable guess is that most activation ranges should be contained in [0, 6]. ``` curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ +tflite_convert \ --output_file=/tmp/foo.cc \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TFLITE \ + --graph_def_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ --inference_type=QUANTIZED_UINT8 \ - --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 \ + --input_arrays=input \ + --output_arrays=MobilenetV1/Predictions/Reshape_1 \ --default_ranges_min=0 \ --default_ranges_max=6 \ - --mean_value=127.5 \ - --std_value=127.5 + --mean_values=128 \ + --std_dev_values=127 ``` ## Specifying input and output arrays -### Multiple output arrays - -The flag `output_arrays` takes in a comma-separated list of output arrays as -seen in the example below. This is useful for models or subgraphs with multiple -outputs. - -``` -curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \ - | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/inception_v1_2016_08_28_frozen.pb \ - --output_file=/tmp/foo.tflite \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TFLITE \ - --inference_type=FLOAT \ - --input_shape=1,224,224,3 \ - --input_array=input \ - --output_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu -``` - ### Multiple input arrays The flag `input_arrays` takes in a comma-separated list of input arrays as seen @@ -172,21 +177,33 @@ inputs. ``` curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \ | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/inception_v1_2016_08_28_frozen.pb \ +tflite_convert \ + --graph_def_file=/tmp/inception_v1_2016_08_28_frozen.pb \ --output_file=/tmp/foo.tflite \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TFLITE \ - --inference_type=FLOAT \ --input_shapes=1,28,28,96:1,28,28,16:1,28,28,192:1,28,28,64 \ --input_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_3/MaxPool_0a_3x3/MaxPool,InceptionV1/InceptionV1/Mixed_3b/Branch_0/Conv2d_0a_1x1/Relu \ - --output_array=InceptionV1/Logits/Predictions/Reshape_1 + --output_arrays=InceptionV1/Logits/Predictions/Reshape_1 ``` Note that `input_shapes` is provided as a colon-separated list. Each input shape corresponds to the input array at the same position in the respective list. +### Multiple output arrays + +The flag `output_arrays` takes in a comma-separated list of output arrays as +seen in the example below. This is useful for models or subgraphs with multiple +outputs. + +``` +curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \ + | tar xzv -C /tmp +tflite_convert \ + --graph_def_file=/tmp/inception_v1_2016_08_28_frozen.pb \ + --output_file=/tmp/foo.tflite \ + --input_arrays=input \ + --output_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu +``` + ### Specifying subgraphs Any array in the input file can be specified as an input or output array in @@ -201,115 +218,57 @@ GraphDef. ``` curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \ | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/inception_v1_2016_08_28_frozen.pb \ +tflite_convert \ + --graph_def_file=/tmp/inception_v1_2016_08_28_frozen.pb \ --output_file=/tmp/foo.pb \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TENSORFLOW_GRAPHDEF \ --input_shapes=1,28,28,96:1,28,28,16:1,28,28,192:1,28,28,64 \ --input_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_3/MaxPool_0a_3x3/MaxPool,InceptionV1/InceptionV1/Mixed_3b/Branch_0/Conv2d_0a_1x1/Relu \ - --output_array=InceptionV1/InceptionV1/Mixed_3b/concat_v2 + --output_arrays=InceptionV1/InceptionV1/Mixed_3b/concat_v2 ``` -Note that the final representation of an on-device inference workload (say, in -TensorFlow Lite FlatBuffers format) tends to have coarser granularity than the -very fine granularity of the TensorFlow GraphDef representation. For example, -while a fully-connected layer is typically represented as at least four separate -ops in TensorFlow GraphDef (Reshape, MatMul, BiasAdd, Relu...), it is typically -represented as a single "fused" op (FullyConnected) in the converter's optimized -representation and in the final on-device representation (e.g. in TensorFlow -Lite FlatBuffer format). As the level of granularity gets coarser, some +Note that the final representation in TensorFlow Lite FlatBuffers tends to have +coarser granularity than the very fine granularity of the TensorFlow GraphDef +representation. For example, while a fully-connected layer is typically +represented as at least four separate ops in TensorFlow GraphDef (Reshape, +MatMul, BiasAdd, Relu...), it is typically represented as a single "fused" op +(FullyConnected) in the converter's optimized representation and in the final +on-device representation. As the level of granularity gets coarser, some intermediate arrays (say, the array between the MatMul and the BiasAdd in the -TensorFlow GraphDef) are dropped. When specifying intermediate arrays as -`--input_arrays` / `--output_arrays`, it is desirable (and often required) to -specify arrays that are meant to survive in the final form of the graph, after -fusing. These are typically the outputs of activation functions (since -everything in each layer until the activation function tends to get fused). - -## Other conversions supported by TOCO - -The converter accepts both TENSORFLOW_GRAPHDEF and TFLITE file formats as both -`--input_format` and `--output_format`. This means that conversion to and from -any supported format is possible. - -### Optimize a TensorFlow GraphDef - -Same-format "conversions" can be used to optimize and simplify a graph or be -used to [get a subgraph](#specifying-subgraphs) of a graph. The flag -`--inference_type` is not required because TensorFlow graphs, including those -containing the -[`FakeQuant*`](https://www.tensorflow.org/api_guides/python/array_ops#Fake_quantization) -ops are always float graphs. - -``` -curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ - | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ - --output_file=/tmp/foo.pb \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TENSORFLOW_GRAPHDEF \ - --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 -``` +TensorFlow GraphDef) are dropped. -### Convert a TensorFlow Lite FlatBuffer back into TensorFlow GraphDef format - -The converter supports file format conversions from TensorFlow Lite, back into -TensorFlow GraphDef format. - -``` -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/foo.tflite \ - --output_file=/tmp/foo.pb \ - --input_format=TFLITE \ - --output_format=TENSORFLOW_GRAPHDEF \ - --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 -``` +When specifying intermediate arrays as `--input_arrays` and `--output_arrays`, +it is desirable (and often required) to specify arrays that are meant to survive +in the final form of the graph, after fusing. These are typically the outputs of +activation functions (since everything in each layer until the activation +function tends to get fused). ## Logging -### Graph "video" logging - -When `--dump_graphviz=` is used (see the section on [graph -visualizations](#graph-visualizations)), one may additionally pass -`--dump_graphviz_video`, which causes a graph visualization to be dumped after -each individual graph transformation. This results in thousands of files. -Typically, one would then bisect into these files to understand when a given -change was introduced in the graph. ## Graph visualizations TOCO can export a graph to the GraphViz Dot format for easy visualization via -either the `--output_format` flag or the `--dump_graphviz` flag. The subsections -below outline the use cases for each. +either the `--output_format` flag or the `--dump_graphviz_dir` flag. The +subsections below outline the use cases for each. ### Using `--output_format=GRAPHVIZ_DOT` The first way to get a graphviz rendering is to pass `GRAPHVIZ_DOT` into `--output_format`. This results in a plausible visualization of the graph. This -reduces the requirements that normally exist during conversion between other -input and output formats. For example, this may be useful if conversion from -TENSORFLOW_GRAPHDEF to TFLITE is failing. +reduces the requirements that exist during conversion between other input and +output formats. This may be useful if conversion from TENSORFLOW_GRAPHDEF to +TFLITE is failing. ``` curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ +tflite_convert \ + --graph_def_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ --output_file=/tmp/foo.dot \ - --input_format=TENSORFLOW_GRAPHDEF \ --output_format=GRAPHVIZ_DOT \ --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 + --input_arrays=input \ + --output_arrays=MobilenetV1/Predictions/Reshape_1 ``` The resulting `.dot` file can be rendered into a PDF as follows: @@ -330,49 +289,35 @@ Example PDF files are viewable online in the next section. ### Using `--dump_graphviz` -The second way to get a graphviz rendering is to pass the `--dump_graphviz=` +The second way to get a graphviz rendering is to pass the `--dump_graphviz_dir` flag, specifying a destination directory to dump GraphViz rendering to. Unlike -the previous approach, this one allows you to keep your real command-line (with -your real `--output_format` and other flags) unchanged, just appending a -`--dump_graphviz=` flag to it. This provides a visualization of the actual graph -during a specific conversion process. +the previous approach, this one retains the original output format. This +provides a visualization of the actual graph resulting from a specific +conversion process. ``` curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ +tflite_convert \ + --graph_def_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ --output_file=/tmp/foo.tflite \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TFLITE \ - --inference_type=FLOAT \ - --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 \ - --dump_graphviz=/tmp -``` - -This generates a few files in the destination directory, here `/tmp`. The two -most important files are: - -``` -/tmp/toco_AT_IMPORT.dot -/tmp/toco_AFTER_TRANSFORMATIONS.dot + --input_arrays=input \ + --output_arrays=MobilenetV1/Predictions/Reshape_1 \ + --dump_graphviz_dir=/tmp ``` -`toco_AT_IMPORT.dot` represents the graph as it was imported from -`--input_file`, before any transformation was applied to it (besides some -transformations that are applied immediately while importing). This tends to be -a complex visualization with limited information, but is useful especially in -situations where a conversion command fails (this file is generated even if the -conversion subsequently fails). +This generates a few files in the destination directory. The two most important +files are `toco_AT_IMPORT.dot` and `/tmp/toco_AFTER_TRANSFORMATIONS.dot`. +`toco_AT_IMPORT.dot` represents the original graph containing only the +transformations done at import time. This tends to be a complex visualization +with limited information about each node. It is useful in situations where a +conversion command fails. `toco_AFTER_TRANSFORMATIONS.dot` represents the graph after all transformations -were applied to it, just before it was exported to the `--output_file`. -Typically, this is a much smaller graph with more information about each node. +were applied to it, just before it is exported. Typically, this is a much +smaller graph with more information about each node. -Again, these can be rendered to PDFs: +As before, these can be rendered to PDFs: ``` dot -Tpdf -O /tmp/toco_*.dot @@ -383,6 +328,14 @@ Sample output files can be seen here: * [toco_AT_IMPORT.dot.pdf](https://storage.googleapis.com/download.tensorflow.org/example_images/toco_AT_IMPORT.dot.pdf) * [toco_AFTER_TRANSFORMATIONS.dot.pdf](https://storage.googleapis.com/download.tensorflow.org/example_images/toco_AFTER_TRANSFORMATIONS.dot.pdf). +### Graph "video" logging + +When `--dump_graphviz_dir` is used, one may additionally pass +`--dump_graphviz_video`. This causes a graph visualization to be dumped after +each individual graph transformation, resulting in thousands of files. +Typically, one would then bisect into these files to understand when a given +change was introduced in the graph. + ### Legend for the graph visualizations * Operators are red square boxes with the following hues of red: diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md index 8085ae07489816c38677ff792e7ac71f1a75fa71..decc8a45a40ffba2a27320ce8391b1916391d744 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md +++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md @@ -1,7 +1,8 @@ # TensorFlow Lite Optimizing Converter command-line glossary -This page is complete reference of command-line flags. It is complemented by the -following other documents: +This page is complete reference of command-line flags used by TOCO's command +line starting from TensorFlow 1.9 up until the most recent build of TensorFlow. +It is complemented by the following other documents: * [README](../README.md) * [Command-line examples](cmdline_examples.md) @@ -16,116 +17,81 @@ Table of contents: ## High-level flags -The following high level flags specify the location of the input and output +The following high level flags specify the details of the input and output files. The flag `--output_file` is always required. Additionally, either -`--input_file` or `--savedmodel_directory` is required. - -* `--savedmodel_directory`. Type: string. Specifies the full path to the - directory containing the SavedModel. -* `--savedmodel_tagset`. Type: string. Default: +`--graph_def_file`, `--saved_model_dir` or `--keras_model_file` is required. + +* `--output_file`. Type: string. Specifies the full path of the output file. +* `--graph_def_file`. Type: string. Specifies the full path of the input + GraphDef file frozen using + [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py). +* `--saved_model_dir`. Type: string. Specifies the full path to the directory + containing the SavedModel. +* `--keras_model_file`. Type: string. Specifies the full path of the HDF5 file + containing the tf.keras model. +* `--output_format`. Type: string. Default: `TFLITE`. Specifies the format of + the output file. Allowed values: + * `TFLITE`: TensorFlow Lite FlatBuffer format. + * `GRAPHVIZ_DOT`: GraphViz `.dot` format containg a visualization of the + graph after graph transformations. + * Note that passing `GRAPHVIZ_DOT` to `--output_format` leads to loss + of TFLite specific transformations. Therefore, the resulting + visualization may not reflect the final set of graph + transformations. To get a final visualization with all graph + transformations use `--dump_graphviz` instead. + +The following flags specify optional parameters when using SavedModels. + +* `--saved_model_tag_set`. Type: string. Default: [kSavedModelTagServe](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/tag_constants.h). Specifies a comma-separated set of tags identifying the MetaGraphDef within the SavedModel to analyze. All tags in the tag set must be specified. -* `--input_file`. Type: string. Specifies the path of the input file. This may - be either an absolute or a relative path. -* `--output_file`. Type: string. Specifies the path of the output file. - -The following high level flags specify the types of the input and output files: - -* `--input_format`. Type: string. Default: `TENSORFLOW_GRAPHDEF`. Specifies - the format of the input file. Allowed values: - * `TENSORFLOW_GRAPHDEF` — The TensorFlow GraphDef format. Both - binary and text proto formats are allowed. - * `TFLITE` — The TensorFlow Lite FlatBuffers format. -* `--output_format`. Type: string. Default: `TFLITE`. Specifies the format of - the output file. Allowed values: - * `TENSORFLOW_GRAPHDEF` — The TensorFlow GraphDef format. Always - produces a file in binary (not text) proto format. - * `TFLITE` — The TensorFlow Lite FlatBuffers format. - * Whether a float or quantized TensorFlow Lite file will be produced - depends on the `--inference_type` flag. - * `GRAPHVIZ_DOT` — The GraphViz `.dot` format. This asks the - converter to generate a reasonable graphical representation of the graph - after simplification by a generic set of transformation. - * A typical `dot` command line to view the resulting graph might look - like: `dot -Tpdf -O file.dot`. - * Note that since passing this `--output_format` means losing the - information of which output format you actually care about, and - since the converter's transformations depend on the specific output - format, the resulting visualization may not fully reflect what you - would get on the actual output format that you are using. To avoid - that concern, and generally to get a visualization of exactly what - you get in your actual output format as opposed to just a merely - plausible visualization of a model, consider using `--dump_graphviz` - instead and keeping your true `--output_format`. +* `--saved_model_signature_key`. Type: string. Default: + [DEFAULT_SERVING_SIGNATURE_DEF_KEY](https://www.tensorflow.org/api_docs/python/tf/saved_model/signature_constants). + Specifies the key identifying the SignatureDef containing inputs and + outputs. ## Model flags *Model flags* provide additional information about the model stored in the input file. -* `--output_array`. Type: string. Specifies a single array as the output - activations. Incompatible with `--output_arrays`. -* `--output_arrays`. Type: comma-separated list of strings. Specifies a list - of arrays as the output activations, for models with multiple outputs. - Incompatible with `--output_array`. -* `--input_array`. Type: string. Specifies a single array as the input - activations. Incompatible with `--input_arrays`. -* `--input_arrays`. Type: comma-separated list of strings. Specifies a list of - arrays as the input activations, for models with multiple inputs. - Incompatible with `--input_array`. -* `--batch_size`. Type: integer. Default: 1. Specifies the batch size for the - model. Replaces the first dimension of an input size array if undefined. Use - only with SavedModels when neither `--input_shape` nor `input_shapes` flags - are specified. Incompatible with GraphDefs. - -When `--input_array` is used, the following flags are available to provide -additional information about the single input array: - -* `--input_shape`. Type: comma-separated list of integers. Specifies the shape - of the input array, in TensorFlow convention: starting with the outer-most - dimension (the dimension corresponding to the largest offset stride in the - array layout), ending with the inner-most dimension (the dimension along - which array entries are typically laid out contiguously in memory). - * For example, a typical vision model might pass - `--input_shape=1,60,80,3`, meaning a batch size of 1 (no batching), an - input image height of 60, an input image width of 80, and an input image - depth of 3, for the typical case where the input image is a RGB bitmap - (3 channels, depth=3) stored by horizontal scanlines (so 'width' is the - next innermost dimension after 'depth'). -* `--mean_value` and `--std_value`. Type: floating-point. The decimal point - character is always the dot (`.`) regardless of the locale. These specify - the (de-)quantization parameters of the input array, when it is quantized. - * The meaning of mean_value and std_value is as follows: each quantized - value in the quantized input array will be interpreted as a mathematical - real number (i.e. as an input activation value) according to the - following formula: +* `--input_arrays`. Type: comma-separated list of strings. Specifies the list + of names of input activation tensors. +* `--output_arrays`. Type: comma-separated list of strings. Specifies the list + of names of output activation tensors. + +The following flags define properties of the input tensors. Each item in the +`--input_arrays` flag should correspond to each item in the following flags +based on index. + +* `--input_shapes`. Type: colon-separated list of comma-separated lists of + integers. Each comma-separated list of integers gives the shape of one of + the input arrays specified in [TensorFlow + convention](https://www.tensorflow.org/versions/r1.2/programmers_guide/dims_types#shape). + * Example: `--input_shapes=1,60,80,3` for a typical vision model means a + batch size of 1, an input image height of 60, an input image width of + 80, and an input image depth of 3 (representing RGB channels). + * Example: `--input_arrays=foo,bar --input_shapes=2,3:4,5,6` means "foo" + has a shape of [2, 3] and "bar" has a shape of [4, 5, 6]. +* `--std_dev_values`, `--mean_values`. Type: comma-separated list of integers. + These specify the (de-)quantization parameters of the input array, when it + is quantized. + * The meaning of `mean_values` and `std_dev_values` is as follows: each + quantized value in the quantized input array will be interpreted as a + mathematical real number (i.e. as an input activation value) according + to the following formula: * `real_value = (quantized_input_value - mean_value) / std_value`. * When performing float inference (`--inference_type=FLOAT`) on a quantized input, the quantized input would be immediately dequantized by the inference code according to the above formula, before proceeding with float inference. * When performing quantized inference - (`--inference_type=QUANTIZED_UINT8`), no dequantization is ever to be - performed by the inference code; however, the quantization parameters of - all arrays, including those of the input arrays as specified by - mean_value and std_value, all participate in the determination of the - fixed-point multipliers used in the quantized inference code. - -When `--input_arrays` is used, the following flags are available to provide -additional information about the multiple input arrays: - -* `--input_shapes`. Type: colon-separated list of comma-separated lists of - integers. Each comma-separated list of integer gives the shape of one of the - input arrays specified in `--input_arrays`, in the same order. See - `--input_shape` for details. - * Example: `--input_arrays=foo,bar --input_shapes=2,3:4,5,6` means that - there are two input arrays. The first one, "foo", has shape [2,3]. The - second one, "bar", has shape [4,5,6]. -* `--mean_values`, `--std_values`. Type: comma-separated lists of - floating-point numbers. Each number gives the corresponding value for one of - the input arrays specified in `--input_arrays`, in the same order. See - `--mean_value`, `--std_value` for details. + (`--inference_type=QUANTIZED_UINT8`), no dequantization is performed by + the inference code. However, the quantization parameters of all arrays, + including those of the input arrays as specified by `mean_value` and + `std_dev_value`, determine the fixed-point multipliers used in the + quantized inference code. ## Transformation flags @@ -133,21 +99,13 @@ additional information about the multiple input arrays: the graph, i.e. they specify requested properties that the output file should have. -* `--inference_type`. Type: string. Sets the type of real-number arrays in the - output file, that is, controls the representation (quantization) of real - numbers in the output file, except for input arrays, which are controlled by - `--inference_input_type`. - - This flag only impacts real-number arrays. By "real-number" we mean float - arrays, and quantized arrays. This excludes plain integer arrays, strings - arrays, and every other data type. +* `--inference_type`. Type: string. Default: `FLOAT`. Data type of all + real-number arrays in the output file except for input arrays (defined by + `--inference_input_type`). Must be `{FLOAT, QUANTIZED_UINT8}`. - For real-number arrays, the impact of this flag is to allow the output file - to choose a different real-numbers representation (quantization) from what - the input file used. For any other types of arrays, changing the data type - would not make sense. - - Specifically: + This flag only impacts real-number arrays including float and quantized + arrays. This excludes all other data types including plain integer arrays + and string arrays. Specifically: * If `FLOAT`, then real-numbers arrays will be of type float in the output file. If they were quantized in the input file, then they get @@ -155,66 +113,54 @@ have. * If `QUANTIZED_UINT8`, then real-numbers arrays will be quantized as uint8 in the output file. If they were float in the input file, then they get quantized. - * If not set, then all real-numbers arrays retain the same type in the - output file as they have in the input file. - -* `--inference_input_type`. Type: string. Similar to inference_type, but - allows to control specifically the quantization of input arrays, separately - from other arrays. - - If not set, then the value of `--inference_type` is implicitly used, i.e. by - default input arrays are quantized like other arrays. - - Like `--inference_type`, this only affects real-number arrays. By - "real-number" we mean float arrays, and quantized arrays. This excludes - plain integer arrays, strings arrays, and every other data type. - - The typical use for this flag is for vision models taking a bitmap as input, - typically with uint8 channels, yet still requiring floating-point inference. - For such image models, the uint8 input is quantized, i.e. the uint8 values - are interpreted as real numbers, and the quantization parameters used for - such input arrays are their `mean_value`, `std_value` parameters. - -* `--default_ranges_min`, `--default_ranges_max`. Type: floating-point. The - decimal point character is always the dot (`.`) regardless of the locale. - These flags enable what is called "dummy quantization". If defined, their - effect is to define fallback (min, max) range values for all arrays that do - not have a properly specified (min, max) range in the input file, thus - allowing to proceed with quantization of non-quantized or - incorrectly-quantized input files. This enables easy performance prototyping - ("how fast would my model run if I quantized it?") but should never be used - in production as the resulting quantized arithmetic is inaccurate. - -* `--drop_fake_quant`. Type: boolean. Default: false. Causes fake-quantization - nodes to be dropped from the graph. This may be used to recover a plain - float graph from a fake-quantized graph. - -* `--reorder_across_fake_quant`. Type: boolean. Default: false. Normally, - fake-quantization nodes must be strict boundaries for graph transformations, - in order to ensure that quantized inference has the exact same arithmetic - behavior as quantized training --- which is the whole point of quantized - training and of FakeQuant nodes in the first place. However, that entails - subtle requirements on where exactly FakeQuant nodes must be placed in the - graph. Some quantized graphs have FakeQuant nodes at unexpected locations, - that prevent graph transformations that are necessary in order to generate a - well-formed quantized representation of these graphs. Such graphs should be - fixed, but as a temporary work-around, setting this - reorder_across_fake_quant flag allows the converter to perform necessary - graph transformations on them, at the cost of no longer faithfully matching - inference and training arithmetic. - -* `--quantize_weights`. Type: boolean. Default: false. Store weights as - quantized weights followed by dequantize operations. Computation is still - done in float, but reduces model size (at the cost of accuracy and latency). + +* `--inference_input_type`. Type: string. Data type of a real-number input + array in the output file. By default the `--inference_type` is used as type + of all of the input arrays. Flag is primarily intended for generating a + float-point graph with a quantized input array. A Dequantized operator is + added immediately after the input array. Must be `{FLOAT, QUANTIZED_UINT8}`. + + The flag is typically used for vision models taking a bitmap as input but + requiring floating-point inference. For such image models, the uint8 input + is quantized and the quantization parameters used for such input arrays are + their `mean_value` and `std_dev_value` parameters. + +* `--default_ranges_min`, `--default_ranges_max`. Type: floating-point. + Default value for the (min, max) range values used for all arrays without a + specified range. Allows user to proceed with quantization of non-quantized + or incorrectly-quantized input files. These flags produce models with low + accuracy. They are intended for easy experimentation with quantization via + "dummy quantization". + +* `--drop_control_dependency`. Type: boolean. Default: True. Indicates whether + to drop control dependencies silently. This is due to TensorFlow Lite not + supporting control dependencies. + +* `--reorder_across_fake_quant`. Type: boolean. Default: False. Indicates + whether to reorder FakeQuant nodes in unexpected locations. Used when the + location of the FakeQuant nodes is preventing graph transformations + necessary to convert the graph. Results in a graph that differs from the + quantized training graph, potentially causing differing arithmetic behavior. + +* `--allow_custom_ops`. Type: string. Default: False. Indicates whether to + allow custom operations. When false, any unknown operation is an error. When + true, custom ops are created for any op that is unknown. The developer will + need to provide these to the TensorFlow Lite runtime with a custom resolver. + +* `--quantize_weights`. Type: boolean. Default: False. Indicates whether to + store weights as quantized weights followed by dequantize operations. + Computation is still done in float, but reduces model size (at the cost of + accuracy and latency). ## Logging flags -The following flags allow to generate graph visualizations of the actual graph -at various points during transformations: +The following flags generate graph visualizations of the graph as +[GraphViz](https://www.graphviz.org/) `.dot` files at various points during +graph transformations: -* `--dump_graphviz=/path` enables dumping of the graphs at various stages of - processing as GraphViz `.dot` files. Generally preferred over - `--output_format=GRAPHVIZ_DOT` as this allows you to keep your actually - relevant `--output_format`. -* `--dump_graphviz_video` enables dumping of the graph after every single - graph transformation (for debugging purposes). +* `--dump_graphviz_dir`. Type: string. Specifies the full path of the + directory to output GraphViz `.dot` files. Outputs the graph immediately + after reading in the graph and after all of the transformations have been + completed. +* `--dump_graphviz_video`. Type: boolean. Outputs GraphViz after every graph + transformation. Requires `--dump_graphviz_dir` to be specified. diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md index a7841a685528fb18bb08f1943278339a2daec16a..3799eac0a1181afe3b63d2f8651745c2ec61f5e0 100644 --- a/tensorflow/contrib/lite/toco/g3doc/python_api.md +++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md @@ -15,11 +15,15 @@ Table of contents: * [Exporting a GraphDef from tf.Session](#basic-graphdef-sess) * [Exporting a GraphDef from file](#basic-graphdef-file) * [Exporting a SavedModel](#basic-savedmodel) + * [Exporting a tf.keras File](#basic-keras-file) * [Complex examples](#complex) * [Exporting a quantized GraphDef](#complex-quant) * [TensorFlow Lite Python interpreter](#interpreter) * [Using the interpreter from a model file](#interpreter-file) * [Using the interpreter from model data](#interpreter-data) +* [Additional instructions](#additional-instructions) + * [Build from source code](#latest-package) + * [Converting models prior to TensorFlow 1.9.](#pre-tensorflow-1.9) ## High-level overview @@ -31,15 +35,17 @@ designing a model that can be targeted to devices with mobile. ## API -The API for converting TensorFlow models to TensorFlow Lite is -`tf.contrib.lite.TocoConverter`. The API for calling the Python intepreter is +The API for converting TensorFlow models to TensorFlow Lite as of TensorFlow 1.9 +is `tf.contrib.lite.TocoConverter`. The API for calling the Python intepreter is `tf.contrib.lite.Interpreter`. `TocoConverter` provides class methods based on the original format of the model. `TocoConverter.from_session()` is available for GraphDefs. -`TocoConverter.from_saved_model()` is available for SavedModels. Example usages -for simple float-point models are shown in [Basic Examples](#basic). Examples -usages for more complex models is shown in [Complex Examples](#complex). +`TocoConverter.from_saved_model()` is available for SavedModels. +`TocoConverter.from_keras_model_file()` is available for `tf.Keras` files. +Example usages for simple float-point models are shown in [Basic +Examples](#basic). Examples usages for more complex models is shown in [Complex +Examples](#complex). **NOTE**: Currently, `TocoConverter` will cause a fatal error to the Python interpreter when the conversion fails. This will be remedied as soon as @@ -111,6 +117,51 @@ For more complex SavedModels, the optional parameters that can be passed into `output_arrays`, `tag_set` and `signature_key`. Details of each parameter are available by running `help(tf.contrib.lite.TocoConverter)`. +### Exporting a tf.keras File + +The following example shows how to convert a `tf.keras` model into a TensorFlow +Lite FlatBuffer. + +```python +import tensorflow as tf + +converter = tf.contrib.lite.TocoConverter.from_keras_model_file("keras_model.h5") +tflite_model = converter.convert() +open("converted_model.tflite", "wb").write(tflite_model) +``` + +The `tf.keras` file must contain both the model and the weights. A comprehensive +example including model construction can be seen below. + +```python +import numpy as np +import tensorflow as tf + +# Generate tf.keras model. +model = tf.keras.models.Sequential() +model.add(tf.keras.layers.Dense(2, input_shape=(3,))) +model.add(tf.keras.layers.RepeatVector(3)) +model.add(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(3))) +model.compile(loss=tf.keras.losses.MSE, + optimizer=tf.keras.optimizers.RMSprop(lr=0.0001), + metrics=[tf.keras.metrics.categorical_accuracy], + sample_weight_mode='temporal') + +x = np.random.random((1, 3)) +y = np.random.random((1, 3, 3)) +model.train_on_batch(x, y) +model.predict(x) + +# Save tf.keras model in HDF5 format. +keras_file = "keras_model.h5" +tf.keras.models.save_model(model, keras_file) + +# Convert to TensorFlow Lite model. +converter = tf.contrib.lite.TocoConverter.from_keras_model_file(keras_file) +tflite_model = converter.convert() +open("converted_model.tflite", "wb").write(tflite_model) +``` + ## Complex examples For models where the default value of the attributes is not sufficient, the @@ -200,3 +251,18 @@ with tf.Session() as sess: interpreter = tf.contrib.lite.Interpreter(model_content=tflite_model) interpreter.allocate_tensors() ``` + +## Additional instructions + +### Build from source code + +In order to run the latest version of the TOCO Python API, clone the TensorFlow +repository, configure the installation, and build and install the pip package. +Detailed instructions are available +[here](https://www.tensorflow.org/install/install_sources). + +### Converting models prior to TensorFlow 1.9. + +To use TOCO in TensorFlow 1.7 and TensorFlow 1.8, use the `toco_convert` +function. Run `help(tf.contrib.lite.toco_convert)` to get details about accepted +parameters. diff --git a/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg b/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg index a47c088991299159be39bc490149720dae43eb53..262e13a591b998c4f38f0a9f44a5b385f612df90 100644 --- a/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg +++ b/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc index 0fffab574ddd8ad75ec07ae4442f363a36ed289e..1ea83abf8eb1b49f649e81def29857094cd0c2d7 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc @@ -38,6 +38,16 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) { // Depthwise conv does not support dilation return false; } + auto& input_array = model->GetArray(conv_op->inputs[0]); + if (!input_array.has_shape()) { + // Shapes not propagated yet + return false; + } + if (input_array.shape().dims(3) != 1) { + // Not a pure convolution: Conv does accumulation across the depth + // dimension. + return false; + } auto& weights_array = model->GetArray(conv_op->inputs[1]); if (!weights_array.buffer) { // Yield until the weights are resolved as a constant array. @@ -46,11 +56,6 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) { if (weights_array.data_type != ArrayDataType::kFloat) { return false; } - if (weights_array.shape().dims(3) != 1) { - // Not a pure convolution: Conv does accumulation across the depth - // dimension. - return false; - } // At this point we know we have a pure conv. Rewrite it as DepthwiseConv. AddMessageF( "%s is purely convolutional (input/weights depth is 1), replacing it by " diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc index 5ab399206ba5376e2ff7c5c7028a1ea3e9b92a03..b689be07926ecd9be4cc317735dc88eb90950e13 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc @@ -23,7 +23,7 @@ namespace toco { bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) { auto tile_it = model->operators.begin() + op_index; - if (tile_it->get()->type != OperatorType::kTensorFlowTile) { + if (tile_it->get()->type != OperatorType::kTile) { return false; } auto* tile_op = static_cast(tile_it->get()); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc index 498c864bde6d656c8318e981204cb42cb3a4d03f..2c7ffe488477ef1a544dfe6f36a6e0d1ac40aa96 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc @@ -111,7 +111,7 @@ bool DequantizeArray(const string& array_name, auto* op_outputting_array = GetOpWithOutput(*model, array_name); if (op_outputting_array) { - if (op_outputting_array->type == OperatorType::kTensorFlowReshape) { + if (op_outputting_array->type == OperatorType::kReshape) { return true; } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc index 708ecf6e0a96811ab274fbb25f748f562cd3afad..e80ed036b311cfc586c40ece410ef6a6432a0cd9 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc @@ -26,17 +26,38 @@ namespace toco { namespace { +int GetOutputDepthFromWeights(const Model& model, const Operator& op) { + const string& weights_name = op.inputs[1]; + const auto& weights_shape = model.GetArray(weights_name).shape(); + if (op.type == OperatorType::kConv || + op.type == OperatorType::kFullyConnected) { + return weights_shape.dims(0); + } + if (op.type == OperatorType::kDepthwiseConv) { + return weights_shape.dims(3); + } + LOG(FATAL) << "Unhandled operator type"; + return 0; +} + bool ProcessLinearOperator(Model* model, Operator* op) { if (op->inputs.size() >= 3) { return false; } const string& output_name = op->outputs[0]; + const string& weights_name = op->inputs[1]; + if (!model->GetArray(weights_name).has_shape()) { + return false; + } + const int depth = GetOutputDepthFromWeights(*model, *op); const string& bias_name = AvailableArrayName(*model, output_name + "_bias"); op->inputs.push_back(bias_name); DCHECK_EQ(op->inputs.size(), 3); auto& bias_array = model->GetOrCreateArray(bias_name); bias_array.data_type = ArrayDataType::kFloat; - + bias_array.mutable_shape()->mutable_dims()->push_back(depth); + auto& bias_buffer = bias_array.GetMutableBuffer(); + bias_buffer.data.resize(depth, 0.f); return true; } } // namespace diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc index 394fa349e2663e2806344f27a96a5132a2d4a810..75642bbc37be6b3140e5b79a463ca70b5786d772 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc @@ -122,7 +122,7 @@ bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model, case OperatorType::kFullyConnected: { weights_index = 1; const auto& fc_op = static_cast(op); - CHECK(!fc_op.experimental_shuffled_weights) + CHECK(fc_op.weights_format == FullyConnectedWeightsFormat::kDefault) << "This graph transformation expects to run before FC weights get " "shuffled."; break; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index 62a09acdfbb553161e480051aa506486b9adec47..4025fede6f160d7ad0fb09be99c246adb93b43a6 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -192,7 +192,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantGather) DECLARE_GRAPH_TRANSFORMATION(ResolveMultiplyByZero) DECLARE_GRAPH_TRANSFORMATION(Dequantize) DECLARE_GRAPH_TRANSFORMATION(UnpartitionEmbeddingLookup) -DECLARE_GRAPH_TRANSFORMATION(ExperimentalShuffleFCWeights) +DECLARE_GRAPH_TRANSFORMATION(ShuffleFCWeights) class PropagateDefaultMinMax : public GraphTransformation { public: diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc index bda6dce22be0f0ca83eb8339ad17573b0267c18c..39f55208e453bdd946cfc8bbbacdc05b793c5d99 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -133,24 +133,20 @@ bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) { } bool HardcodeMinMaxForSplit(Model* model, Operator* op) { - for (const auto& output : op->outputs) { - if (model->GetArray(output).minmax) { - LOG(WARNING) << "Skipping min-max setting for " << LogName(*op) - << " because output " << output << " already has min-max."; - return false; - } - } // Data is in second input. auto& input_array = model->GetArray(op->inputs[1]); if (!input_array.minmax) { return false; - } else { - for (const auto& output : op->outputs) { - auto& array = model->GetArray(output); + } + bool changed = false; + for (const auto& output : op->outputs) { + auto& array = model->GetArray(output); + if (!array.minmax || !(array.GetMinMax() == input_array.GetMinMax())) { + changed = true; array.GetOrCreateMinMax() = *input_array.minmax; } - return true; } + return changed; } // The output of average or max pooling is within the same range as its input. @@ -353,7 +349,7 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { changed = HardcodeMinMaxForConcatenation(model, op); break; - case OperatorType::kTensorFlowSplit: + case OperatorType::kSplit: changed = HardcodeMinMaxForSplit(model, op); break; @@ -366,7 +362,7 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { case OperatorType::kSlice: case OperatorType::kStridedSlice: case OperatorType::kSqueeze: - case OperatorType::kTensorFlowReshape: + case OperatorType::kReshape: case OperatorType::kPad: case OperatorType::kGather: case OperatorType::kTranspose: diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc index 419a0776a6b987a18df059d3c1d4bf4370cd24d8..b78efd7fc3602dc2d6e03fd28d694c344b61c17c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc @@ -44,10 +44,9 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { const auto* div_or_mul_op = div_it->get(); OperatorType expected_op_type_producing_div_or_mul_input; if (div_or_mul_op->type == OperatorType::kDiv) { - expected_op_type_producing_div_or_mul_input = OperatorType::kTensorFlowSqrt; + expected_op_type_producing_div_or_mul_input = OperatorType::kSqrt; } else if (div_or_mul_op->type == OperatorType::kMul) { - expected_op_type_producing_div_or_mul_input = - OperatorType::kTensorFlowRsqrt; + expected_op_type_producing_div_or_mul_input = OperatorType::kRsqrt; } else { return false; } @@ -75,8 +74,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { Operator* add_op = nullptr; Operator* op_producing_add_input = nullptr; if (op_producing_sqrt_or_rsqrt_input->type == OperatorType::kAdd || - op_producing_sqrt_or_rsqrt_input->type == - OperatorType::kTensorFlowMaximum) { + op_producing_sqrt_or_rsqrt_input->type == OperatorType::kMaximum) { add_op = op_producing_sqrt_or_rsqrt_input; bool add_can_be_removed = false; CHECK_EQ(op_producing_sqrt_or_rsqrt_input->inputs.size(), 2); @@ -113,7 +111,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { Operator* sum_op = add_op ? op_producing_add_input : op_producing_sqrt_or_rsqrt_input; - if (sum_op->type != OperatorType::kTensorFlowSum) { + if (sum_op->type != OperatorType::kSum) { AddMessageF( "Giving up trying to identify L2Normalization subgraph: " "expected Sum op, got %s", @@ -122,7 +120,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { } Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]); - if (square_op->type != OperatorType::kTensorFlowSquare) { + if (square_op->type != OperatorType::kSquare) { AddMessageF( "Giving up trying to identify L2Normalization subgraph: " "expected Square op, got %s", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc index e4d52476c649de53b3ab663f53ce7a5538dbb5ab..705e73779b7f74698149d5e9e56f69a371326ceb 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc @@ -41,7 +41,7 @@ std::vector>::iterator FindOperator( bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) { const auto sqrt_it = model->operators.begin() + op_index; const auto* sqrt_op = sqrt_it->get(); - if (sqrt_op->type != OperatorType::kTensorFlowSqrt) { + if (sqrt_op->type != OperatorType::kSqrt) { return false; } @@ -52,6 +52,13 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) { const Operator* square_op; Operator* prev_to_sqrt_op = GetOpWithOutput(*model, sqrt_op->inputs[0]); + if (prev_to_sqrt_op == nullptr) { + AddMessageF( + "Giving up trying to identify L2Pool subgraph: " + "expected AveragePool op, but Sqrt op has no preceding op"); + return false; + } + if (prev_to_sqrt_op->type != OperatorType::kAveragePool) { AddMessageF( "Giving up trying to identify L2Pool subgraph: " @@ -65,7 +72,7 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) { square_op = GetOpWithOutput(*model, avpool_op->inputs[0]); CHECK_EQ(square_op->inputs.size(), 1); - if (square_op->type != OperatorType::kTensorFlowSquare) { + if (square_op->type != OperatorType::kSquare) { AddMessageF( "Giving up trying to identify L2Pool subgraph: " "expected Square op, got %s", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc index e9842524c829b839b97b3453a36c41efe186efbb..3ca7f53512bb7e307f9a2bc5cfb7c27b45cc052c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc @@ -35,21 +35,6 @@ std::vector>::iterator FindOperator( return it; } -bool GetStateArrayForBackEdge(const Model& model, - const string& back_edge_source_array, - string* state_array = nullptr) { - for (const auto& rnn_state : model.flags.rnn_states()) { - if (back_edge_source_array == rnn_state.back_edge_source_array()) { - // Found LSTM cell output - if (state_array) { - *state_array = rnn_state.state_array(); - } - return true; - } - } - return false; -} - // Returns true if the given operator has exactly 1 input, and is connected to // the given op_type. // We use kNone to indicate an input unattached to an operator output. Usually @@ -231,11 +216,6 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { &state_combine_add)) { return false; } - string prev_state; - if (!GetStateArrayForBackEdge(*model, state_output_tanh->inputs[0], - &prev_state)) { - return false; - } // State forget & remember addition Operator *state_forget_mul, *state_remember_mul; @@ -244,9 +224,7 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { &state_remember_mul)) { return false; } - if (state_forget_mul->inputs[0] != prev_state) { - return false; - } + const string prev_state = state_forget_mul->inputs[0]; // State forget gate Operator* state_forget_sig; @@ -266,26 +244,26 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { // State remember "information" activation function Operator* fc_output_split; - if (!MatchOperatorInputs(*state_info_tanh, *model, - OperatorType::kTensorFlowSplit, &fc_output_split)) { + if (!MatchOperatorInputs(*state_info_tanh, *model, OperatorType::kSplit, + &fc_output_split)) { return false; } // State remember gate activation function Operator* tmp; - if (!MatchOperatorInputs(*state_remember_sig, *model, - OperatorType::kTensorFlowSplit, &tmp) || + if (!MatchOperatorInputs(*state_remember_sig, *model, OperatorType::kSplit, + &tmp) || (tmp != fc_output_split)) { return false; } // State forget gate activation function - if (!MatchOperatorInputs(*state_forget_sig, *model, - OperatorType::kTensorFlowSplit, &tmp) || + if (!MatchOperatorInputs(*state_forget_sig, *model, OperatorType::kSplit, + &tmp) || (tmp != fc_output_split)) { return false; } // Fully connected output activation function - if (!MatchOperatorInputs(*fc_output_sig, *model, - OperatorType::kTensorFlowSplit, &tmp) || + if (!MatchOperatorInputs(*fc_output_sig, *model, OperatorType::kSplit, + &tmp) || (tmp != fc_output_split)) { return false; } @@ -306,8 +284,8 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { return false; } - if (static_cast(fully_connected) - ->experimental_shuffled_weights) { + if (static_cast(fully_connected)->weights_format != + FullyConnectedWeightsFormat::kDefault) { // Not yet implemented: experimental shuffled weights in fused LSTM cell. return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc index e6e3dfa1de9c9fdd5e759fd547d11a7b8c95d837..46d1fce50e5d6e2a74cf5461d731e46469dde5bf 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc @@ -74,6 +74,12 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) { lstm_cell_op->inputs[kInputTensor] = curr_op->inputs[LstmCellOperator::ACTIV_OUTPUT]; + // Previous states. + lstm_cell_op->inputs[kInputActivationStateTensor] = + curr_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]; + lstm_cell_op->inputs[kInputCellStateTensor] = + curr_op->inputs[LstmCellOperator::PREV_STATE_INPUT]; + // Get original weight tensor and decompose 1 tensor to 8 sub tensors. Array& kernel = model->GetArray(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT]); @@ -160,10 +166,6 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) { // Erase curr lstm op being replaced. DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT], model); DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::BIASES_INPUT], model); - DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT], - model); - DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::PREV_STATE_INPUT], - model); model->operators.erase(FindOp(*model, curr_op)); return true; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc index bddb563206f763a756685d196836fa41825cf045..94820a016622a12654e91967737e05fc91ed404c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc @@ -60,24 +60,22 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) { // Follow sequences of min+max and max+min. First get the leading op. const auto op_it = model->operators.begin() + op_index; const auto* op_0 = op_it->get(); - if (op_0->type != OperatorType::kTensorFlowMinimum && - op_0->type != OperatorType::kTensorFlowMaximum) { + if (op_0->type != OperatorType::kMinimum && + op_0->type != OperatorType::kMaximum) { return false; } // Get the paired op and ensure it's the counter to the first. const auto* op_1 = GetOpWithInput(*model, op_0->outputs[0]); if (!op_1 || - (op_1->type != OperatorType::kTensorFlowMinimum && - op_1->type != OperatorType::kTensorFlowMaximum) || + (op_1->type != OperatorType::kMinimum && + op_1->type != OperatorType::kMaximum) || op_0->type == op_1->type) { return false; } - const auto* min_op = - op_0->type == OperatorType::kTensorFlowMinimum ? op_0 : op_1; - const auto* max_op = - op_0->type == OperatorType::kTensorFlowMaximum ? op_0 : op_1; + const auto* min_op = op_0->type == OperatorType::kMinimum ? op_0 : op_1; + const auto* max_op = op_0->type == OperatorType::kMaximum ? op_0 : op_1; if (min_op->inputs.size() != 2 || max_op->inputs.size() != 2) { return false; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h index 1c32a781698ec78003ebbf9caff28557924323e5..6d8603a1133a7478647b8bcc49ea1eceba28df31 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h @@ -47,10 +47,14 @@ enum ExtendedLstmCellInputs { kOutputGateBiasTensor = 15, kProjectionWeightsTensor = 16, // Optional kProjectionBiasTensor = 17, // Optional - kExtendedLstmInputCount = 18 + kInputActivationStateTensor = 18, + // The op can handle 18 inputs or 20 inputs. + kInputCellStateTensor = 19, + kExtendedLstmInputCount = 20, }; enum ExtendedLstmCellOutputs { + // TODO(ycling): Make the 2 output state tensors optional. kOutputStateTensor = 0, kCellStateTensor = 1, kOutputTensor = 2, diff --git a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc index 5065004093434475172a39efdcfd26c10c49148b..95bc7f7d4b8b517c1cc5a73b3e85bbd985ce460f 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc @@ -106,7 +106,7 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model, std::size_t op_index) { auto it = model->operators.begin() + op_index; auto* reshape_op = ConvertOperator( - it->get(), OperatorType::kTensorFlowReshape); + it->get(), OperatorType::kReshape); if (reshape_op == nullptr) { return false; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 92d283ca2cc7069f4b80c95ffdadcad81a884cbf..00ab7cbaa90b399ca08bdfba82991fbd5d2c9f7e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -56,22 +56,22 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { // These operators unconditionally produce float outputs SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat); break; - case OperatorType::kTensorFlowLess: - case OperatorType::kTensorFlowLessEqual: - case OperatorType::kTensorFlowGreater: - case OperatorType::kTensorFlowGreaterEqual: - case OperatorType::kTensorFlowEqual: - case OperatorType::kTensorFlowNotEqual: + case OperatorType::kLess: + case OperatorType::kLessEqual: + case OperatorType::kGreater: + case OperatorType::kGreaterEqual: + case OperatorType::kEqual: + case OperatorType::kNotEqual: // These operators unconditionally produce bool outputs SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool); break; case OperatorType::kRank: - case OperatorType::kTensorFlowShape: + case OperatorType::kShape: // These operators only produce int32 outputs. SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32); break; - case OperatorType::kTensorFlowSplit: - case OperatorType::kTensorFlowConcat: + case OperatorType::kSplit: + case OperatorType::kConcat: case OperatorType::kFill: { // These operators produce an output with the same type as their 2nd input CHECK_GE(op->inputs.size(), 2); @@ -135,7 +135,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { model->GetArray(op->outputs[1]).data_type = ArrayDataType ::kInt32; break; } - case OperatorType::kTensorFlowUnsupported: { + case OperatorType::kUnsupported: { auto* unsupported_op = static_cast(op); // Some output tensors from the op could be eliminated by optimization. // This can make unsupported_op->output_data_types have more elements than @@ -175,6 +175,14 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { SetDataTypeForAllOutputs(model, op, data_type); break; } + case OperatorType::kPow: { + CHECK_EQ(op->inputs.size(), 2); + CHECK(model->GetArray(op->inputs[0]).data_type == + model->GetArray(op->inputs[1]).data_type); + const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type; + SetDataTypeForAllOutputs(model, op, data_type); + break; + } default: { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc index 77c08868117382f9daf900da79286e9f9e06d9db..0f2592d05f6e01599735c5138c53ba7779ce805d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc @@ -27,11 +27,21 @@ namespace toco { namespace { -void ChangeArrayDataType(GraphTransformation* transformation, Array* array, +bool ChangeArrayDataType(GraphTransformation* transformation, Array* array, ArrayDataType new_data_type, const MinMax* new_minmax) { + // The code below assumes kInt16, see + // GetQuantizationParamsFromMinMax + if (new_data_type != ArrayDataType::kInt16) { + return false; + } + + bool changed = false; // Ensure the array ends up in the new type (if it hasn't yet been quantized). - array->final_data_type = new_data_type; + if ((array->final_data_type != new_data_type)) { + array->final_data_type = new_data_type; + changed = true; + } if (array->minmax && array->quantization_params) { // The array is already quantized and has min/max info. @@ -70,10 +80,10 @@ void ChangeArrayDataType(GraphTransformation* transformation, Array* array, // Directly change the type as the array was already quantized. array->data_type = new_data_type; - } else { + changed = true; + } else if (!array->quantization_params) { // Array has not yet been quantized so we can just set the final data type // and assign the new min/max value (if provided). - CHECK(!array->quantization_params); if (!array->minmax && new_minmax) { transformation->AddMessageF("Forcing new minmax to %g,%g (%s)", @@ -82,16 +92,18 @@ void ChangeArrayDataType(GraphTransformation* transformation, Array* array, auto& array_minmax = array->GetOrCreateMinMax(); array_minmax.min = new_minmax->min; array_minmax.max = new_minmax->max; + changed = true; } } + return changed; } // Returns true if the op blocks our backward recursive data type propagation. bool DoesOpBlockBackwardPropagation(const Operator& op) { switch (op.type) { case OperatorType::kConcatenation: - case OperatorType::kTensorFlowConcat: - case OperatorType::kTensorFlowConcatV2: + case OperatorType::kConcat: + case OperatorType::kConcatV2: // Concat shouldn't block propagation, but we do expect that all inputs // have the same range. return false; @@ -100,10 +112,10 @@ bool DoesOpBlockBackwardPropagation(const Operator& op) { // FakeQuant so make sure we move across them. case OperatorType::kGather: // Gathers need their parameters changed to the appropriate data type. - case OperatorType::kTensorFlowReshape: + case OperatorType::kReshape: case OperatorType::kTranspose: case OperatorType::kSelect: - case OperatorType::kTensorFlowTile: + case OperatorType::kTile: // Reshapes and transposes don't change values. return false; default: @@ -121,11 +133,11 @@ bool DoesOpInputBlockBackwardPropagation(const Operator& op, int input_index) { // Ignore gather indices. return input_index != 0; break; - case OperatorType::kTensorFlowReshape: + case OperatorType::kReshape: case OperatorType::kTranspose: // Ignore reshape/transpose shapes/dimensions. return input_index != 0; - case OperatorType::kTensorFlowTile: + case OperatorType::kTile: // Ignore tile multiples. return input_index != 0; default: @@ -159,9 +171,8 @@ bool RecursivelyBackwardPropagateDataType(GraphTransformation* transformation, "Adjusting input final data type of array %s from %s to %s", input, ArrayDataTypeName(input_array.final_data_type), ArrayDataTypeName(new_data_type)); - did_change = true; - ChangeArrayDataType(transformation, &input_array, new_data_type, - &new_minmax); + did_change |= ChangeArrayDataType(transformation, &input_array, + new_data_type, &new_minmax); // Walk up into all ops producing the inputs to this op. for (auto& producing_op : model->operators) { @@ -212,9 +223,8 @@ bool RecursivelyForwardPropagateDataType(GraphTransformation* transformation, "Adjusting output final data type of array %s from %s to %s", output, ArrayDataTypeName(output_array.final_data_type), ArrayDataTypeName(new_data_type)); - did_change = true; - ChangeArrayDataType(transformation, &output_array, new_data_type, - nullptr); + did_change |= ChangeArrayDataType(transformation, &output_array, + new_data_type, nullptr); // Walk down into all ops consuming the output of this op. for (auto& consuming_op : model->operators) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index e7da9051d835c30f93838b0c5be45dbcc92a70c1..82b3ab96fe07a7385e678cc9ccfd68ca1d7ce330 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -120,49 +120,7 @@ void ComputeBinaryOperatorOutputSize(const Shape& input_shape_x, CHECK(output_array->has_shape()); } -int GetOutputDepthFromWeights(const Model& model, const Operator& op) { - const string& weights_name = op.inputs[1]; - const auto& weights_shape = model.GetArray(weights_name).shape(); - if (op.type == OperatorType::kConv || - op.type == OperatorType::kFullyConnected) { - return weights_shape.dims(0); - } else if (op.type == OperatorType::kDepthwiseConv) { - return weights_shape.dims(3); - } else { - LOG(FATAL) << "Unhandled operator type"; - } -} - -bool EnsureBiasVectorShape(Model* model, Operator* op) { - const string& weights_name = op->inputs[1]; - const auto& weights_array = model->GetArray(weights_name); - // Yield until weights shape has been resolved. - if (!weights_array.has_shape()) { - return false; - } - - if (op->inputs.size() < 3) { - return false; - } - auto& bias_array = model->GetArray(op->inputs[2]); - if (bias_array.has_shape()) { - return true; - } - - const int output_depth = GetOutputDepthFromWeights(*model, *op); - bias_array.copy_shape(Shape({output_depth})); - - auto& float_buffer = bias_array.GetMutableBuffer(); - float_buffer.data.resize(output_depth, 0); - - return true; -} - void ProcessConvOperator(Model* model, ConvOperator* op) { - if (!EnsureBiasVectorShape(model, op)) { - return; - } - const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { @@ -292,10 +250,6 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { } void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { - if (!EnsureBiasVectorShape(model, op)) { - return; - } - const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { @@ -325,7 +279,7 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { if (!op->depth_multiplier) { op->depth_multiplier = output_depth / input_depth; } - QCHECK_EQ(output_depth, input_depth * op->depth_multiplier) + CHECK_EQ(output_depth, input_depth * op->depth_multiplier) << "input/output depths and depth_multiplier don't match"; const int kheight = weights_shape.dims(1); @@ -410,10 +364,6 @@ void ProcessOpWithShapeInput(Model* model, Operator* op) { } void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) { - if (!EnsureBiasVectorShape(model, op)) { - return; - } - const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { @@ -572,11 +522,11 @@ void ProcessAddNOperator(Model* model, Operator* op) { bool KeepDims(const Operator& op) { switch (op.type) { - case OperatorType::kTensorFlowMin: + case OperatorType::kMin: // Reduction Min return static_cast(op).keep_dims; - case OperatorType::kTensorFlowMax: + case OperatorType::kMax: // Reduction Max return static_cast(op).keep_dims; - case OperatorType::kTensorFlowSum: + case OperatorType::kSum: return static_cast(op).keep_dims; case OperatorType::kMean: return static_cast(op).keep_dims; @@ -1341,8 +1291,8 @@ void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) { op->begin_mask, op->start_indices, op->strides, input_array.shape().dims().data(), axis); int stop_index = tflite::strided_slice::StopForAxis( - op->end_mask, op->stop_indices, op->strides, - input_array.shape().dims().data(), axis); + op->end_mask, op->shrink_axis_mask, op->stop_indices, op->strides, + input_array.shape().dims().data(), axis, start_index); int dim_size = ceil(static_cast(stop_index - start_index) / op->strides[axis]); @@ -1577,14 +1527,14 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kLogistic: case OperatorType::kTanh: case OperatorType::kLocalResponseNormalization: - case OperatorType::kTensorFlowIdentity: + case OperatorType::kIdentity: case OperatorType::kFakeQuant: case OperatorType::kNeg: - case OperatorType::kTensorFlowRsqrt: - case OperatorType::kTensorFlowSqrt: - case OperatorType::kTensorFlowSquare: - case OperatorType::kTensorFlowAll: - case OperatorType::kTensorFlowAssert: + case OperatorType::kRsqrt: + case OperatorType::kSqrt: + case OperatorType::kSquare: + case OperatorType::kAll: + case OperatorType::kAssert: case OperatorType::kCast: case OperatorType::kFloor: case OperatorType::kExp: @@ -1603,14 +1553,15 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kDiv: case OperatorType::kFloorDiv: case OperatorType::kFloorMod: - case OperatorType::kTensorFlowLess: - case OperatorType::kTensorFlowLessEqual: - case OperatorType::kTensorFlowGreater: - case OperatorType::kTensorFlowMaximum: - case OperatorType::kTensorFlowMinimum: - case OperatorType::kTensorFlowGreaterEqual: - case OperatorType::kTensorFlowEqual: - case OperatorType::kTensorFlowNotEqual: + case OperatorType::kLess: + case OperatorType::kLessEqual: + case OperatorType::kGreater: + case OperatorType::kMaximum: // Element-wise Maximum + case OperatorType::kMinimum: // Element-wise Minimum + case OperatorType::kGreaterEqual: + case OperatorType::kEqual: + case OperatorType::kNotEqual: + case OperatorType::kPow: ProcessSimpleBinaryOperator(model, op); break; case OperatorType::kAddN: @@ -1643,7 +1594,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { ProcessFullyConnectedOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowReshape: + case OperatorType::kReshape: ProcessTensorFlowReshapeOperator( model, static_cast(op)); break; @@ -1656,9 +1607,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kL2Pool: ProcessL2PoolOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowMin: - case OperatorType::kTensorFlowMax: - case OperatorType::kTensorFlowSum: + case OperatorType::kMin: // Reduction Min + case OperatorType::kMax: // Reduction Max + case OperatorType::kSum: case OperatorType::kMean: ProcessTensorFlowReductionOperator(model, op); break; @@ -1669,26 +1620,26 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { ProcessSliceOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowSwitch: + case OperatorType::kSwitch: // We can't know the sizes of the outputs until we have resolved the // predicate, and once we have resolved the predicate, the whole // Switch node will get resolved away. // See ResolveTensorFlowSwitch. break; - case OperatorType::kTensorFlowMerge: + case OperatorType::kMerge: // No need to bother resolving TensorFlow Merge ops: other graph // transformations will remove them anyway. // See ResolveTensorFlowMerge. break; - case OperatorType::kTensorFlowSplit: + case OperatorType::kSplit: ProcessTensorFlowSplitOperator(model, static_cast(op)); break; case OperatorType::kSqueeze: ProcessSqueezeOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowConcat: - case OperatorType::kTensorFlowConcatV2: + case OperatorType::kConcat: + case OperatorType::kConcatV2: // Unimplemented, hopefully another graph transformation will // drop it or rewrite it. Concretely, either ResolveTensorFlowConcat // will resolve this node to a DepthConcatenation, or else we have @@ -1704,7 +1655,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kRank: ProcessRankOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowShape: + case OperatorType::kShape: ProcessShapeOperator(model, static_cast(op)); break; case OperatorType::kStack: @@ -1725,7 +1676,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { ProcessLstmCellOperator(model, static_cast(op)); break; case OperatorType::kBatchMatMul: - case OperatorType::kTensorFlowMatMul: + case OperatorType::kMatMul: // MatMul operators are converted to FullyConnected, after which their // shapes are propagated. break; @@ -1750,7 +1701,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kArgMax: ProcessArgMaxOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowUnsupported: + case OperatorType::kUnsupported: break; case OperatorType::kSvdf: ProcessSvdfOperator(model, static_cast(op)); @@ -1772,7 +1723,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { ProcessSparseToDenseOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowTile: + case OperatorType::kTile: ProcessTileOperator(model, static_cast(op)); break; default: diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc index eca2c701f8bbf889088794c939af7082db0734dd..58885b4950733bfc9d394127e597a08232cd5663 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -33,7 +33,7 @@ namespace { bool SupportsQuantization(const Operator& op) { auto type = op.type; - if (type == OperatorType::kTensorFlowUnsupported) { + if (type == OperatorType::kUnsupported) { auto* unsupported = static_cast(&op); return unsupported->quantized; } @@ -42,15 +42,13 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kConcatenation || type == OperatorType::kL2Normalization || type == OperatorType::kAdd || type == OperatorType::kAveragePool || type == OperatorType::kMaxPool || - type == OperatorType::kTensorFlowMinimum || - type == OperatorType::kTensorFlowMaximum || + type == OperatorType::kMinimum || type == OperatorType::kMaximum || type == OperatorType::kLogistic || type == OperatorType::kSoftmax || type == OperatorType::kLogSoftmax || type == OperatorType::kSlice || type == OperatorType::kResizeBilinear || - type == OperatorType::kTensorFlowSplit || type == OperatorType::kSub || + type == OperatorType::kSplit || type == OperatorType::kSub || type == OperatorType::kSqueeze || type == OperatorType::kPad || - type == OperatorType::kPadV2 || - type == OperatorType::kTensorFlowReshape || + type == OperatorType::kPadV2 || type == OperatorType::kReshape || type == OperatorType::kTanh || type == OperatorType::kMul || type == OperatorType::kSpaceToBatchND || type == OperatorType::kSpaceToDepth || @@ -58,11 +56,11 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kDepthToSpace || type == OperatorType::kLstmCell || type == OperatorType::kGather || type == OperatorType::kTranspose || type == OperatorType::kMean || - type == OperatorType::kTensorFlowGreater || - type == OperatorType::kTensorFlowGreaterEqual || - type == OperatorType::kTensorFlowLess || - type == OperatorType::kTensorFlowLessEqual || - type == OperatorType::kSelect || type == OperatorType::kArgMax; + type == OperatorType::kGreater || + type == OperatorType::kGreaterEqual || type == OperatorType::kLess || + type == OperatorType::kLessEqual || type == OperatorType::kSelect || + type == OperatorType::kArgMax || type == OperatorType::kRelu || + type == OperatorType::kRelu1 || type == OperatorType::kRelu6; } const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) { @@ -328,14 +326,15 @@ bool ChooseQuantizationForOperatorOutput( output, OperatorTypeName(op.type)); return true; } - if ((op.type == OperatorType::kDepthToSpace) || - (op.type == OperatorType::kSpaceToDepth) || - (op.type == OperatorType::kTensorFlowReshape) || - (op.type == OperatorType::kTensorFlowSplit) || - (op.type == OperatorType::kConcatenation && - model->flags.change_concat_input_ranges())) { + if ((op.type == OperatorType::kConcatenation && + model->flags.change_concat_input_ranges()) || + op.type == OperatorType::kDepthToSpace || + op.type == OperatorType::kSpaceToDepth || + op.type == OperatorType::kReshape || op.type == OperatorType::kSplit || + op.type == OperatorType::kRelu || op.type == OperatorType::kRelu1 || + op.type == OperatorType::kRelu6) { int data_input_index = 0; - if (op.type == OperatorType::kTensorFlowSplit) { + if (op.type == OperatorType::kSplit) { data_input_index = 1; } // Copying and rearrangement ops should preserve the quantization parameters @@ -508,36 +507,47 @@ bool Quantize::Run(Model* model, std::size_t op_index) { // Check if the output of that Dequantize op was not used by any // other operator. We will then erase that Dequantize op. if (!CountOpsWithInput(*model, dequantize_op->outputs[0])) { - // If any of the model's output_arrays was pointing to the - // Dequantize op's output, let it point to the Dequantize op's - // input instead. - for (int i = 0; i < model->flags.output_arrays_size(); i++) { - if (model->flags.output_arrays(i) == dequantize_op->outputs[0]) { - // TODO(b/78013785): never rename output arrays. - if (IsInputArray(*model, dequantize_op->inputs[0])) { - // The op input is an input array and the output is an output - // array and we can't have an array be both. Insert a copy - // op to ensure the two arrays stay separate. - AddMessageF( - "Tried to rename output array %d while removing dequant " - "op %s but array is also an input; inserting copy %s " - "-> %s", - i, LogName(*dequantize_op), model->flags.output_arrays(i), - dequantize_op->inputs[0]); - InsertCopyOperator(model, dequantize_op->inputs[0], - dequantize_op->outputs[0]); - } else { - // Op output is strictly used as an output array, so we can - // just rename the array and directly bypass the op. - AddMessageF( - "Renaming output array %d after removing dequant op %s: " - "%s -> %s", - i, LogName(*dequantize_op), model->flags.output_arrays(i), - dequantize_op->inputs[0]); - model->flags.set_output_arrays(i, dequantize_op->inputs[0]); - model->EraseArray(dequantize_op->outputs[0]); + if (IsDiscardableArray(*model, dequantize_op->outputs[0])) { + // Usual case: we can just discard the dequantize output. + model->EraseArray(dequantize_op->outputs[0]); + } else { + // The dequantize output is not discardable. Special care needed. + // If any of the model's output_arrays was pointing to the + // Dequantize op's output, let it point to the Dequantize op's + // input instead. + for (int i = 0; i < model->flags.output_arrays_size(); i++) { + if (model->flags.output_arrays(i) == + dequantize_op->outputs[0]) { + // TODO(b/78013785): never rename output arrays. + if (IsInputArray(*model, dequantize_op->inputs[0])) { + // The op input is an input array and the output is an + // output array and we can't have an array be both. Insert a + // copy op to ensure the two arrays stay separate. + AddMessageF( + "Tried to rename output array %d while removing " + "dequant " + "op %s but array is also an input; inserting copy %s " + "-> %s", + i, LogName(*dequantize_op), + model->flags.output_arrays(i), + dequantize_op->inputs[0]); + InsertCopyOperator(model, dequantize_op->inputs[0], + dequantize_op->outputs[0]); + } else { + // Op output is strictly used as an output array, so we can + // just rename the array and directly bypass the op. + AddMessageF( + "Renaming output array %d after removing dequant op " + "%s: " + "%s -> %s", + i, LogName(*dequantize_op), + model->flags.output_arrays(i), + dequantize_op->inputs[0]); + model->flags.set_output_arrays(i, dequantize_op->inputs[0]); + model->EraseArray(dequantize_op->outputs[0]); + } + break; } - break; } } model->operators.erase(dequantize_it); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc index 35a0c465327f352863350e7a8af714d16b7be393..73ad326299bbd929afbb8dda2c41b97a126afbe1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc @@ -26,7 +26,7 @@ namespace toco { bool RemoveTensorFlowAssert::Run(Model* model, std::size_t op_index) { const auto assert_it = model->operators.begin() + op_index; const auto* assert_op = assert_it->get(); - if (assert_op->type != OperatorType::kTensorFlowAssert) { + if (assert_op->type != OperatorType::kAssert) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc index 404269bbfd9312bbbab32489783d9e4217ecbd89..7ec7752f25dad1c24b821733c0e6dafbd1cd8bf2 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc @@ -28,7 +28,7 @@ namespace toco { bool RemoveTensorFlowIdentity::Run(Model* model, std::size_t op_index) { const auto passthru_it = model->operators.begin() + op_index; const auto* passthru_op = passthru_it->get(); - if (passthru_op->type != OperatorType::kTensorFlowIdentity) { + if (passthru_op->type != OperatorType::kIdentity) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc index a950fe6442bc656b725a1f0687f4c024f4fb0f84..9f5d8b94507ec11957c3ae55ffca510eeb81ac89 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc @@ -97,7 +97,7 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation, "Cannot remove %s, neither its main input nor its output may be " "discarded", LogName(*passthru_op)); - if (passthru_op->type != OperatorType::kTensorFlowReshape && + if (passthru_op->type != OperatorType::kReshape && model->GetArray(main_input_name).has_shape()) { // We can't remove either array but we can remove the op. Converting it to // a reshape gives us some hope of later on fixing that (either in the diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc index eaee1c662b7cedb2baec7be47e12e348c3e7b25c..142c876b154755ac9c6b93e560f22ec8d6ec6563 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc @@ -47,11 +47,11 @@ bool IsTrivialMinMax(GraphTransformation* transformation, const Model& model, double clamp_min; double clamp_max; switch (op_type) { - case OperatorType::kTensorFlowMinimum: + case OperatorType::kMinimum: // Element-wise Minimum clamp_min = -std::numeric_limits::infinity(); clamp_max = clamp_value; break; - case OperatorType::kTensorFlowMaximum: + case OperatorType::kMaximum: // Element-wise Maximum clamp_min = clamp_value; clamp_max = std::numeric_limits::infinity(); break; @@ -72,8 +72,8 @@ bool IsTrivialMinMax(GraphTransformation* transformation, const Model& model, bool RemoveTrivialQuantizedMinMax::Run(Model* model, std::size_t op_index) { const auto it = model->operators.begin() + op_index; auto* op = it->get(); - if ((op->type != OperatorType::kTensorFlowMinimum && - op->type != OperatorType::kTensorFlowMaximum) || + if ((op->type != OperatorType::kMinimum && + op->type != OperatorType::kMaximum) || op->inputs.size() != 2) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc index e28d8cf01eafee64e08ac2cc4b43ea7c227456c2..404f27e067402474484d3ee8e23595fb9f93a6c9 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc @@ -30,7 +30,7 @@ namespace { bool IsReshapeTrivial(const Model& model, const Operator& op, RemoveTrivialReshape* transformation) { - CHECK(op.type == OperatorType::kTensorFlowReshape); + CHECK(op.type == OperatorType::kReshape); // One way in which a reshape can be trivial is if its // output shape is == its input shape @@ -58,7 +58,7 @@ bool IsReshapeTrivial(const Model& model, const Operator& op, // is only consumed by another reshape. if (CountOpsWithInput(model, op.outputs[0]) == 1) { const auto* next_op = GetOpWithInput(model, op.outputs[0]); - if (next_op->type == OperatorType::kTensorFlowReshape) { + if (next_op->type == OperatorType::kReshape) { transformation->AddMessageF( "%s is trivial because its output is only consumed by another " "Reshape op %s", @@ -75,7 +75,7 @@ bool IsReshapeTrivial(const Model& model, const Operator& op, bool RemoveTrivialReshape::Run(Model* model, std::size_t op_index) { const auto reshape_it = model->operators.begin() + op_index; auto* reshape_op = reshape_it->get(); - if (reshape_op->type != OperatorType::kTensorFlowReshape) { + if (reshape_op->type != OperatorType::kReshape) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc index 1956ab2d2021cda84a0d715534923d6174c30dd1..dde91234a8240f4518cd105c2cc4e79102735980 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc @@ -48,7 +48,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) { for (const auto& rnn_state : model->flags.rnn_states()) { if (output == rnn_state.state_array()) { CHECK(op->type == OperatorType::kFill || - op->type == OperatorType::kTensorFlowIdentity); + op->type == OperatorType::kIdentity); found_output_as_rnn_state_array = true; break; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc index 9f5b7920cb937b021eb23fc1d5fdc3c1ff18a72d..550de83018f25a7aa4da82707fedb86434615fb0 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc @@ -37,8 +37,8 @@ bool IsElementwiseOperator(OperatorType optype) { case OperatorType::kRelu1: case OperatorType::kRelu6: case OperatorType::kTanh: - case OperatorType::kTensorFlowSqrt: - case OperatorType::kTensorFlowSquare: + case OperatorType::kSqrt: + case OperatorType::kSquare: return true; default: return false; @@ -51,7 +51,7 @@ bool IsMoveOperator(OperatorType optype) { case OperatorType::kExpandDims: case OperatorType::kSpaceToDepth: case OperatorType::kSqueeze: - case OperatorType::kTensorFlowReshape: + case OperatorType::kReshape: case OperatorType::kTranspose: return true; default: diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc index 9e7fe1b1ccd851dd998e59e75ff798f52f7c6e5a..c907a597cb719b68dbf36868a75e49a7c5181423 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc @@ -123,8 +123,8 @@ bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) { } TensorFlowReshapeOperator* reshape_op = - ConvertOperator( - reshape_it->get(), OperatorType::kTensorFlowReshape); + ConvertOperator(reshape_it->get(), + OperatorType::kReshape); if (reshape_op == nullptr) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc index a06919e228dc2084f8943a714a0ca111d013c159..b8b35161d77e5b6dd8c30e03959dba3c60d1d56c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc @@ -50,7 +50,7 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) { // will delete this op. return false; } - std::vector crops_buffer = + const std::vector& crops_buffer = crops_array.GetBuffer().data; for (int i = 0; i < crops_dims[0]; ++i) { op->before_crops.push_back(crops_buffer[i * 2]); @@ -62,7 +62,7 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) { if (!block_shape_array.has_shape()) return false; const std::vector& block_shape_dims = block_shape_array.shape().dims(); CHECK_EQ(block_shape_dims.size(), 1); - std::vector block_shape_buffer = + const std::vector& block_shape_buffer = block_shape_array.GetBuffer().data; for (int i = 0; i < block_shape_dims[0]; ++i) { op->block_shape.push_back(block_shape_buffer[i]); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc index 6e78653fad238085da5ba66166884093ea9b0214..f7e5aa6609bd4f7eb2a95750125e30a7803b36e1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc @@ -145,17 +145,17 @@ void EvaluateBinaryOperatorOnConstantInputs(Model* model, outval = floor(val0 / val1); } else if (binary_op->type == OperatorType::kFloorMod) { outval = val0 - (floor(val0 / val1) * val1); - } else if (binary_op->type == OperatorType::kTensorFlowMinimum) { + } else if (binary_op->type == OperatorType::kMinimum) { outval = std::min(val0, val1); - } else if (binary_op->type == OperatorType::kTensorFlowMaximum) { + } else if (binary_op->type == OperatorType::kMaximum) { outval = std::max(val0, val1); - } else if (binary_op->type == OperatorType::kTensorFlowLess) { + } else if (binary_op->type == OperatorType::kLess) { outval = val0 < val1; - } else if (binary_op->type == OperatorType::kTensorFlowLessEqual) { + } else if (binary_op->type == OperatorType::kLessEqual) { outval = val0 <= val1; - } else if (binary_op->type == OperatorType::kTensorFlowGreater) { + } else if (binary_op->type == OperatorType::kGreater) { outval = val0 > val1; - } else if (binary_op->type == OperatorType::kTensorFlowGreaterEqual) { + } else if (binary_op->type == OperatorType::kGreaterEqual) { outval = val0 >= val1; } else { LOG(FATAL) << "should not get here"; @@ -198,12 +198,12 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) { binary_op->type != OperatorType::kDiv && binary_op->type != OperatorType::kFloorDiv && binary_op->type != OperatorType::kFloorMod && - binary_op->type != OperatorType::kTensorFlowMinimum && - binary_op->type != OperatorType::kTensorFlowMaximum && - binary_op->type != OperatorType::kTensorFlowLess && - binary_op->type != OperatorType::kTensorFlowLessEqual && - binary_op->type != OperatorType::kTensorFlowGreater && - binary_op->type != OperatorType::kTensorFlowGreaterEqual) { + binary_op->type != OperatorType::kMinimum && + binary_op->type != OperatorType::kMaximum && + binary_op->type != OperatorType::kLess && + binary_op->type != OperatorType::kLessEqual && + binary_op->type != OperatorType::kGreater && + binary_op->type != OperatorType::kGreaterEqual) { return false; } CHECK_EQ(binary_op->inputs.size(), 2); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc index 7e7ad383e7789891f5396845241e70143dc8b76f..41562ab393694d76c5cb6c5df5f7df2a71f893f5 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc @@ -25,7 +25,7 @@ namespace toco { bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) { auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); - if (base_op->type != OperatorType::kTensorFlowReshape) { + if (base_op->type != OperatorType::kReshape) { return false; } const auto* op = static_cast(base_op); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc index 9ea01acd05364224ce219bed533c999793a2a2f1..8a0e3e8995839a737b5671701a97b514b0fc7bf1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc @@ -22,8 +22,7 @@ namespace toco { bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) { const auto it = model->operators.begin() + op_index; const auto* op = it->get(); - if (!(op->type == OperatorType::kTensorFlowShape || - op->type == OperatorType::kRank)) { + if (!(op->type == OperatorType::kShape || op->type == OperatorType::kRank)) { return false; } @@ -48,7 +47,7 @@ bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) { // Compute the output CHECK(!output_array.buffer); auto& output_buffer = output_array.GetMutableBuffer(); - if (op->type == OperatorType::kTensorFlowShape) { + if (op->type == OperatorType::kShape) { // Copy the input shape into the output buffer. output_buffer.data = input_array.shape().dims(); } else if (op->type == OperatorType::kRank) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc index 69db1942cd52af810acf38a818997c71122d8500..a4d5f1923a1dffdff1ef51eb5317fa5794a8bc27 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc @@ -41,7 +41,7 @@ void Stack(Model* model, StackOperator const& op) { const auto& input_array = model->GetArray(op.inputs[i]); int input_size = RequiredBufferSizeForShape(input_array.shape()); memcpy(&output_data[dst_offset], &input_array.GetBuffer().data[0], - input_size * sizeof(Type)); + input_size * ElementSize(Type)); dst_offset += input_size; } CHECK_EQ(dst_offset, output_data.size()); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc index 1dd52e906900e997f282740404a81b9fcd21e867..9d8bd4fc39344a4ea1fa4942a2a99ec535b5bee8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc @@ -38,6 +38,7 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array, CHECK_EQ(op.new_axis_mask, 0); int num_input_axes = op.start_indices.size(); + CHECK_EQ(num_input_axes, op.start_indices.size()); CHECK_EQ(num_input_axes, op.stop_indices.size()); CHECK_EQ(num_input_axes, op.strides.size()); @@ -49,11 +50,16 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array, // Initialize source coordinate Shape const& input_shape = input_array.shape(); Buffer const& input_buffer = input_array.GetBuffer(); - std::vector src_coord(op.start_indices.size()); + std::vector src_coord(num_input_axes); + std::vector stop_for_axis(num_input_axes); for (int axis = 0; axis < num_input_axes; axis++) { - src_coord[axis] = tflite::strided_slice::StartForAxis( + int start = tflite::strided_slice::StartForAxis( op.begin_mask, op.start_indices, op.strides, input_shape.dims().data(), axis); + src_coord[axis] = start; + stop_for_axis[axis] = tflite::strided_slice::StopForAxis( + op.end_mask, op.shrink_axis_mask, op.stop_indices, op.strides, + input_shape.dims().data(), axis, start); } // In order to handle any number (N) of dimensions, we copy elements one by @@ -76,9 +82,7 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array, } // Check if we've overflowed. - int stop = tflite::strided_slice::StopForAxis( - op.end_mask, op.stop_indices, op.strides, input_shape.dims().data(), - axis); + int stop = stop_for_axis[axis]; if (tflite::strided_slice::LoopCondition(src_coord[axis], stop, stride)) { // Reset axis and set carry src_coord[axis] = tflite::strided_slice::StartForAxis( @@ -155,14 +159,7 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) { break; } - // Erase input array if no longer used - if (IsDiscardableArray(*model, op->inputs[0]) && - CountOpsWithInput(*model, op->inputs[0]) == 1) { - model->EraseArray(op->inputs[0]); - } - - // Erase the operator - model->operators.erase(it); + DeleteOpAndArraysIfUnused(model, it->get()); return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc index f6c8f79d8d3311dc2294e3ec406a184b2a16a6b5..f89ef85fdb63ca4906c7f016e86bb1f9d8a7099a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc @@ -53,13 +53,13 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { case OperatorType::kCast: case OperatorType::kLog: case OperatorType::kNeg: - case OperatorType::kTensorFlowRsqrt: - case OperatorType::kTensorFlowSqrt: - case OperatorType::kTensorFlowSquare: - case OperatorType::kTensorFlowSum: - case OperatorType::kTensorFlowMin: - case OperatorType::kTensorFlowMax: - case OperatorType::kTensorFlowReshape: + case OperatorType::kRsqrt: + case OperatorType::kSqrt: + case OperatorType::kSquare: + case OperatorType::kSum: + case OperatorType::kMin: // Reduction Min + case OperatorType::kMax: // Reduction Max + case OperatorType::kReshape: case OperatorType::kRelu6: case OperatorType::kRelu1: case OperatorType::kRelu: @@ -103,7 +103,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { // The min-max is only copied for ops that copy data without arithmetic. // In future trivial transpose, etc, can be handled here. - if (unary_op->type == OperatorType::kTensorFlowReshape) { + if (unary_op->type == OperatorType::kReshape) { CopyMinMaxFromFirstInput(*unary_op, model); } @@ -164,10 +164,10 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { } output_float_data[i] = outval; } - } else if (unary_op->type == OperatorType::kTensorFlowReshape) { + } else if (unary_op->type == OperatorType::kReshape) { CHECK(input_buffer_size == output_buffer_size); output_float_data = *input_float_data; - } else if (unary_op->type == OperatorType::kTensorFlowSum) { + } else if (unary_op->type == OperatorType::kSum) { CHECK_EQ(unary_op->inputs.size(), 2) << "Sum needs 2 inputs"; if (!IsConstantParameterArray(*model, unary_op->inputs[1])) { AddMessageF("Axis input is non-constant"); @@ -196,7 +196,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { } output_float_data[i] = sum; } - } else if (unary_op->type == OperatorType::kTensorFlowMin) { + } else if (unary_op->type == OperatorType::kMin) { // At the moment only full reduction across all dimensions is supported. // TODO(starka): Output should not be padded. for (int i = 0; i < output_dims_count; i++) { @@ -207,7 +207,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { min = std::min(min, (*input_float_data)[i]); } output_float_data[0] = min; - } else if (unary_op->type == OperatorType::kTensorFlowMax) { + } else if (unary_op->type == OperatorType::kMax) { // At the moment only full reduction across all dimensions is supported. // TODO(starka): Output should not be padded. for (int i = 0; i < output_dims_count; i++) { @@ -220,9 +220,9 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { output_float_data[0] = max; } else if (unary_op->type == OperatorType::kNeg || unary_op->type == OperatorType::kLog || - unary_op->type == OperatorType::kTensorFlowRsqrt || - unary_op->type == OperatorType::kTensorFlowSqrt || - unary_op->type == OperatorType::kTensorFlowSquare) { + unary_op->type == OperatorType::kRsqrt || + unary_op->type == OperatorType::kSqrt || + unary_op->type == OperatorType::kSquare) { // Element-wise ops. Should have perfectly matching sizes here. for (int i = 0; i < output_dims_count; i++) { CHECK_EQ(output_shape.dims(i), input_shape.dims(i)); @@ -235,11 +235,11 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { outval = -val; } else if (unary_op->type == OperatorType::kLog) { outval = std::log(val); - } else if (unary_op->type == OperatorType::kTensorFlowRsqrt) { + } else if (unary_op->type == OperatorType::kRsqrt) { outval = 1.0f / std::sqrt(val); - } else if (unary_op->type == OperatorType::kTensorFlowSqrt) { + } else if (unary_op->type == OperatorType::kSqrt) { outval = std::sqrt(val); - } else if (unary_op->type == OperatorType::kTensorFlowSquare) { + } else if (unary_op->type == OperatorType::kSquare) { outval = val * val; } else { LOG(FATAL) << "should not get here."; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc index bc70db0bd8c26319fa140616de96452260a01058..8266e2c205b65e9d8a969643f102bb852be9125b 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc @@ -51,11 +51,12 @@ void ReorderAxes(AxesOrder input_axes_order, AxesOrder output_axes_order, } bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) { - auto reorder_it = model->operators.begin() + op_index; - auto* reorder_op = static_cast(reorder_it->get()); - if (reorder_op->type != OperatorType::kReorderAxes) { + auto it = model->operators.begin() + op_index; + auto* op = it->get(); + if (op->type != OperatorType::kReorderAxes) { return false; } + auto* reorder_op = static_cast(op); const auto& input_array_name = reorder_op->inputs[0]; const auto& output_array_name = reorder_op->outputs[0]; auto& input_array = model->GetArray(input_array_name); @@ -95,7 +96,7 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) { // Remove the op and output array. model->EraseArray(output_array_name); - model->operators.erase(reorder_it); + model->operators.erase(it); return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc index 2e063e35548aa5e51c3bcc94a2dfc7992180d014..b615c9a545695e5d14fa5809e0c38a770f23ea24 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc @@ -28,7 +28,7 @@ namespace toco { bool ResolveReshapeAttributes::Run(Model* model, std::size_t op_index) { const auto reshape_it = model->operators.begin() + op_index; auto* reshape_op = reshape_it->get(); - if (reshape_op->type != OperatorType::kTensorFlowReshape) { + if (reshape_op->type != OperatorType::kReshape) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc index dad6aceccfd201b3db07c29c99a8c6ef75bb89a1..fab50bec1fc5ec50cecba53845457931ed59c0b8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc @@ -53,7 +53,7 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) { // will delete this op. return false; } - std::vector paddings_buffer = + const std::vector& paddings_buffer = paddings_array.GetBuffer().data; for (int i = 0; i < paddings_dims[0]; ++i) { op->before_paddings.push_back(paddings_buffer[i * 2]); @@ -66,7 +66,7 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) { if (!block_shape_array.has_shape()) return false; const std::vector& block_shape_dims = block_shape_array.shape().dims(); CHECK_EQ(block_shape_dims.size(), 1); - std::vector block_shape_buffer = + const std::vector& block_shape_buffer = block_shape_array.GetBuffer().data; for (int i = 0; i < block_shape_dims[0]; ++i) { op->block_shape.push_back(block_shape_buffer[i]); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc index dd3e73635ae0215510f0a8d1aee487da5af35700..e8bb85704e1c750300079681b5a12f6a488b6b48 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc @@ -36,7 +36,7 @@ bool ResolveSqueezeAttributes::Run(Model* model, std::size_t op_index) { // If the output is consumed by a reshape op, it's a trivial squeeze. if (CountOpsWithInput(*model, squeeze_op->outputs[0]) == 1) { const auto* next_op = GetOpWithInput(*model, squeeze_op->outputs[0]); - if (next_op->type == OperatorType::kTensorFlowReshape) { + if (next_op->type == OperatorType::kReshape) { AddMessageF( "%s is trivial because its output is only consumed by a " "Reshape op", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc index 5c0c1e3478fa0d94104d1b76bab176b98b314c50..fa5ee899334bdf2d39a6861b0e0c4548142e9d2a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc @@ -28,8 +28,8 @@ namespace toco { bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) { auto concat_it = model->operators.begin() + op_index; const auto* tf_concat_op = concat_it->get(); - if (tf_concat_op->type != OperatorType::kTensorFlowConcat && - tf_concat_op->type != OperatorType::kTensorFlowConcatV2) { + if (tf_concat_op->type != OperatorType::kConcat && + tf_concat_op->type != OperatorType::kConcatV2) { return false; } @@ -38,7 +38,7 @@ bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) { // of inputs: in Concat,the axis is the first input, while in // ConcatV2, it is the last input. std::size_t axis_pos = 0; - if (tf_concat_op->type == OperatorType::kTensorFlowConcatV2) { + if (tf_concat_op->type == OperatorType::kConcatV2) { axis_pos = tf_concat_op->inputs.size() - 1; } const string axis_name = tf_concat_op->inputs[axis_pos]; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc index 2a236d3f98784e8244942f94d5a250b5bc00a8ad..d496f5ae5eeeca5063e23b25498b0ac450e9f946 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc @@ -26,7 +26,7 @@ namespace toco { bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { auto matmul_it = model->operators.begin() + op_index; - if (matmul_it->get()->type != OperatorType::kTensorFlowMatMul) { + if (matmul_it->get()->type != OperatorType::kMatMul) { return false; } const auto* matmul_op = @@ -97,7 +97,7 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { // MatMul op as a FullyConnected. However, TensorFlow skips the Reshape ops if // the input doesn't need reshaping, so we can't just match (Reshape, MatMul) // pairs. - if (previous_op && previous_op->type == OperatorType::kTensorFlowReshape) { + if (previous_op && previous_op->type == OperatorType::kReshape) { AddMessageF("Combining %s and %s into %s", LogName(*previous_op), LogName(*matmul_op), LogName(*fc_op)); const auto& previous_op_output = previous_op->outputs[0]; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc index 38e0005890ac10410df4ddb5290be8fcc948c349..4edffe3d48fd880c0261b34fc407b8e2ac66ccb9 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc @@ -27,7 +27,7 @@ namespace toco { bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) { const auto merge_it = model->operators.begin() + op_index; const auto* merge_op = merge_it->get(); - if (merge_op->type != OperatorType::kTensorFlowMerge) { + if (merge_op->type != OperatorType::kMerge) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc index a418073441f1241a5acb1164b36f332828ea2e99..da8e7a2d1c06cf89b9708b404da7667565245f8f 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc @@ -27,7 +27,7 @@ namespace toco { bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) { const auto switch_it = model->operators.begin() + op_index; const auto* switch_op = switch_it->get(); - if (switch_op->type != OperatorType::kTensorFlowSwitch) { + if (switch_op->type != OperatorType::kSwitch) { return false; } @@ -92,7 +92,7 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) { if (*input_it == switch_op->outputs[nonselected_output_index]) { // Let us guard our assumption that only Merge nodes consume the outputs // of Switch nodes: - CHECK(other_op->type == OperatorType::kTensorFlowMerge); + CHECK(other_op->type == OperatorType::kMerge); input_it = other_op->inputs.erase(input_it); } else { ++input_it; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/experimental_shuffle_fc_weights.cc b/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc similarity index 96% rename from tensorflow/contrib/lite/toco/graph_transformations/experimental_shuffle_fc_weights.cc rename to tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc index c00cdcb944b085dda41033b95c96537cc2e047c3..22c258cec5fde4144c4b048d5ec60a8604362cbb 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/experimental_shuffle_fc_weights.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc @@ -24,14 +24,14 @@ limitations under the License. namespace toco { -bool ExperimentalShuffleFCWeights::Run(Model* model, std::size_t op_index) { +bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) { Operator* op = model->operators[op_index].get(); if (op->type != OperatorType::kFullyConnected) { return false; } FullyConnectedOperator* fc_op = static_cast(op); // Exit if this FC op already has shuffled weights - if (fc_op->experimental_shuffled_weights) { + if (fc_op->weights_format != FullyConnectedWeightsFormat::kDefault) { return false; } const Array& input_array = model->GetArray(fc_op->inputs[0]); @@ -135,7 +135,7 @@ bool ExperimentalShuffleFCWeights::Run(Model* model, std::size_t op_index) { CHECK_EQ(shuffled_data_ptr, shuffled_data.data() + rows * cols); // Switch this FC op to using the shuffled weights. weights_data = std::move(shuffled_data); - fc_op->experimental_shuffled_weights = true; + fc_op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8; AddMessageF("Applied experimental shuffling to the weights of %s", LogName(*op)); // Add a second output array to this FC op, serving as a workspace to perform diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index cd4f034dfea57b6d379b67a90ba4fa3fe3d615d5..55e39d963f97eb35790b460ed8c634b32abf490f 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -263,7 +263,11 @@ tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor, output_array->GetMutableBuffer().data; output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); CHECK_GE(output_int_data.size(), input_flat_size); - if (input_tensor.int_val_size()) { + if (input_tensor.int_val_size() == 1) { + for (int i = 0; i < input_flat_size; i++) { + output_int_data[i] = input_tensor.int_val(0); + } + } else if (input_tensor.int_val_size() == input_flat_size) { for (int i = 0; i < input_tensor.int_val_size(); i++) { output_int_data[i] = input_tensor.int_val(i); } @@ -296,7 +300,11 @@ tensorflow::Status ImportInt32Array(const TensorProto& input_tensor, output_array->GetMutableBuffer().data; output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); CHECK_GE(output_int_data.size(), input_flat_size); - if (input_tensor.int_val_size()) { + if (input_tensor.int_val_size() == 1) { + for (int i = 0; i < input_flat_size; i++) { + output_int_data[i] = input_tensor.int_val(0); + } + } else if (input_tensor.int_val_size() == input_flat_size) { for (int i = 0; i < input_tensor.int_val_size(); i++) { output_int_data[i] = input_tensor.int_val(i); } @@ -328,8 +336,12 @@ tensorflow::Status ImportInt64Array(const TensorProto& input_tensor, output_array->GetMutableBuffer().data; output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); CHECK_GE(output_int_data.size(), input_flat_size); - if (input_tensor.int64_val_size()) { - for (int i = 0; i < input_tensor.int64_val_size(); i++) { + if (input_tensor.int64_val_size() == 1) { + for (int i = 0; i < input_flat_size; i++) { + output_int_data[i] = input_tensor.int64_val(0); + } + } else if (input_tensor.int64_val_size() == input_flat_size) { + for (int i = 0; i < input_tensor.float_val_size(); i++) { output_int_data[i] = input_tensor.int64_val(i); } } else if (input_tensor.tensor_content().size() == @@ -362,7 +374,11 @@ tensorflow::Status ImportBoolArray(const TensorProto& input_tensor, output_bool_data.resize(RequiredBufferSizeForShape(output_array->shape()), false); CHECK_GE(output_bool_data.size(), input_flat_size); - if (input_tensor.bool_val_size()) { + if (input_tensor.bool_val_size() == 1) { + for (int i = 0; i < input_flat_size; i++) { + output_bool_data[i] = input_tensor.bool_val(0); + } + } else if (input_tensor.bool_val_size() == input_flat_size) { for (int i = 0; i < input_tensor.bool_val_size(); i++) { output_bool_data[i] = input_tensor.bool_val(i); } @@ -426,18 +442,19 @@ int GetInputsCount(const NodeDef& node, return i; } } - return node.input_size(); - } else { - return node.input_size(); } + return node.input_size(); } -void CheckInputsCount(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - int expected_input_count) { - QCHECK_EQ(GetInputsCount(node, tf_import_flags), expected_input_count) - << node.op() << " node expects " << expected_input_count - << " input(s) other than control dependencies: " << node.DebugString(); +tensorflow::Status CheckInputsCount( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + int expected_input_count) { + if (GetInputsCount(node, tf_import_flags) != expected_input_count) { + return tensorflow::errors::FailedPrecondition( + node.op(), " node expects ", expected_input_count, + " input(s) other than control dependencies: ", node.DebugString()); + } + return tensorflow::Status::OK(); } template @@ -504,7 +521,7 @@ tensorflow::Status ConvertConvOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { CHECK_EQ(node.op(), "Conv2D"); - CheckInputsCount(node, tf_import_flags, 2); + TF_RETURN_IF_ERROR(CheckInputsCount(node, tf_import_flags, 2)); // We only support NHWC, which is the default data_format. // So if data_format is not defined, we're all good. @@ -574,11 +591,11 @@ tensorflow::Status ConvertConvOperator( return tensorflow::Status::OK(); } -void ConvertDepthwiseConvOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertDepthwiseConvOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "DepthwiseConv2dNative"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); // We only support NHWC, which is the default data_format. // So if data_format is not defined, we're all good. @@ -625,13 +642,14 @@ void ConvertDepthwiseConvOperator(const NodeDef& node, LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; } model->operators.emplace_back(conv); + return tensorflow::Status::OK(); } -void ConvertDepthToSpaceOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertDepthToSpaceOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "DepthToSpace"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); auto* op = new DepthToSpaceOperator; @@ -640,13 +658,14 @@ void ConvertDepthToSpaceOperator(const NodeDef& node, op->block_size = GetIntAttr(node, "block_size"); QCHECK_GE(op->block_size, 2); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertSpaceToDepthOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSpaceToDepthOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "SpaceToDepth"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); tensorflow::DataType dtype = GetDataTypeAttr(node, "T"); if (dtype != DT_FLOAT && dtype != DT_UINT8 && dtype != DT_INT32 && @@ -662,13 +681,14 @@ void ConvertSpaceToDepthOperator(const NodeDef& node, op->block_size = GetIntAttr(node, "block_size"); QCHECK_GE(op->block_size, 2); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertBiasAddOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertBiasAddOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "BiasAdd"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); const auto& input_name = node.input(0); const auto& bias_name = node.input(1); @@ -678,13 +698,14 @@ void ConvertBiasAddOperator(const NodeDef& node, biasadd->inputs.push_back(bias_name); biasadd->outputs.push_back(node.name()); model->operators.emplace_back(biasadd); + return tensorflow::Status::OK(); } -void ConvertRandomUniform(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertRandomUniform( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "RandomUniform"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); CHECK_EQ(GetDataTypeAttr(node, "T"), DT_INT32); auto op = absl::make_unique(); @@ -695,11 +716,12 @@ void ConvertRandomUniform(const NodeDef& node, op->seed2 = GetIntAttr(node, "seed2"); CHECK(model != nullptr); model->operators.emplace_back(std::move(op)); + return tensorflow::Status::OK(); } -void ConvertIdentityOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertIdentityOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" || node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient"); auto* op = new TensorFlowIdentityOperator; @@ -716,13 +738,14 @@ void ConvertIdentityOperator(const NodeDef& node, op->inputs.push_back(input_name); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertFakeQuantWithMinMaxArgs( +tensorflow::Status ConvertFakeQuantWithMinMaxArgs( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); auto* op = new FakeQuantOperator; op->inputs.push_back(node.input(0)); op->minmax.reset(new MinMax); @@ -733,9 +756,10 @@ void ConvertFakeQuantWithMinMaxArgs( // tf.fake_quant_with_min_max_args num_bits defaults to 8. op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8; model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertFakeQuantWithMinMaxVars( +tensorflow::Status ConvertFakeQuantWithMinMaxVars( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars"); @@ -751,14 +775,14 @@ void ConvertFakeQuantWithMinMaxVars( op->outputs.push_back(node.name()); op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8; model->operators.emplace_back(op); + return tensorflow::Status::OK(); } - -void ConvertSqueezeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSqueezeOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Squeeze"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); auto* op = new SqueezeOperator; op->inputs.push_back(node.input(0)); op->outputs.push_back(node.name()); @@ -772,13 +796,14 @@ void ConvertSqueezeOperator(const NodeDef& node, } model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertSumOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSumOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Sum"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowSumOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -787,13 +812,14 @@ void ConvertSumOperator(const NodeDef& node, if (HasAttr(node, "keep_dims")) { op->keep_dims = GetBoolAttr(node, "keep_dims"); } + return tensorflow::Status::OK(); } -void ConvertSplitOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSplitOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Split"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowSplitOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -804,13 +830,14 @@ void ConvertSplitOperator(const NodeDef& node, } op->num_split = num_split; model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertSwitchOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSwitchOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Switch"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowSwitchOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -818,13 +845,14 @@ void ConvertSwitchOperator(const NodeDef& node, // Switch operators have two outputs: "name" and "name:1". op->outputs.push_back(node.name() + ":1"); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertSoftmaxOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSoftmaxOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Softmax"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto& input_name = node.input(0); auto* softmax = new SoftmaxOperator; softmax->inputs.push_back(input_name); @@ -833,13 +861,14 @@ void ConvertSoftmaxOperator(const NodeDef& node, CHECK(!node.attr().count("beta")); // Stab in the dark, just in case. softmax->beta = 1.f; model->operators.emplace_back(softmax); + return tensorflow::Status::OK(); } -void ConvertLRNOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertLRNOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "LRN"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto& input_name = node.input(0); auto* lrn = new LocalResponseNormalizationOperator; lrn->inputs.push_back(input_name); @@ -849,13 +878,14 @@ void ConvertLRNOperator(const NodeDef& node, lrn->alpha = GetFloatAttr(node, "alpha"); lrn->beta = GetFloatAttr(node, "beta"); model->operators.emplace_back(lrn); + return tensorflow::Status::OK(); } -void ConvertMaxPoolOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertMaxPoolOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "MaxPool"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto& input_name = node.input(0); // We only support NHWC, which is the default data_format. // So if data_format is not defined, we're all good. @@ -891,13 +921,14 @@ void ConvertMaxPoolOperator(const NodeDef& node, LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; } model->operators.emplace_back(maxpool); + return tensorflow::Status::OK(); } -void ConvertAvgPoolOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertAvgPoolOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "AvgPool"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto& input_name = node.input(0); // We only support NHWC, which is the default data_format. // So if data_format is not defined, we're all good. @@ -929,13 +960,13 @@ void ConvertAvgPoolOperator(const NodeDef& node, LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; } model->operators.emplace_back(avgpool); + return tensorflow::Status::OK(); } - -void ConvertBatchMatMulOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CheckInputsCount(node, tf_import_flags, 2); +tensorflow::Status ConvertBatchMatMulOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); // https://www.tensorflow.org/versions/r0.12/api_docs/python/math_ops/matrix_math_functions CHECK(!HasAttr(node, "adj_a") || (GetBoolAttr(node, "adj_a") == false)); @@ -945,12 +976,13 @@ void ConvertBatchMatMulOperator(const NodeDef& node, batch_matmul->inputs = {node.input(0), node.input(1)}; batch_matmul->outputs = {node.name()}; model->operators.emplace_back(batch_matmul); + return tensorflow::Status::OK(); } -void ConvertMatMulOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CheckInputsCount(node, tf_import_flags, 2); +tensorflow::Status ConvertMatMulOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); // Transpose flags should be easy to support, but we don't have a // GraphDef with them to test on at the moment. @@ -967,11 +999,12 @@ void ConvertMatMulOperator(const NodeDef& node, matmul->inputs = {node.input(0), node.input(1)}; matmul->outputs = {node.name()}; model->operators.emplace_back(matmul); + return tensorflow::Status::OK(); } -void ConvertConcatOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertConcatOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { Operator* op = nullptr; if (node.op() == "Concat") { op = new TensorFlowConcatOperator; @@ -991,13 +1024,14 @@ void ConvertConcatOperator(const NodeDef& node, } op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } // This method supports simple operators without additional attributes. template -void ConvertSimpleOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSimpleOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { auto* op = new Op; const int num_inputs = GetInputsCount(node, tf_import_flags); for (int i = 0; i < num_inputs; ++i) { @@ -1005,22 +1039,23 @@ void ConvertSimpleOperator(const NodeDef& node, } op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } // This method supports simple operators without additional attributes. template -void ConvertSimpleOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CheckInputsCount(node, tf_import_flags, NumInputs); - ConvertSimpleOperator(node, tf_import_flags, model); +tensorflow::Status ConvertSimpleOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, NumInputs)); + return ConvertSimpleOperator(node, tf_import_flags, model); } -void ConvertMaxOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertMaxOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Max"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowMaxOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -1029,13 +1064,14 @@ void ConvertMaxOperator(const NodeDef& node, if (HasAttr(node, "keep_dims")) { op->keep_dims = GetBoolAttr(node, "keep_dims"); } + return tensorflow::Status::OK(); } -void ConvertMinOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertMinOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Min"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowMinOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -1044,12 +1080,12 @@ void ConvertMinOperator(const NodeDef& node, if (HasAttr(node, "keep_dims")) { op->keep_dims = GetBoolAttr(node, "keep_dims"); } + return tensorflow::Status::OK(); } - -void ConvertUnsupportedOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertUnsupportedOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { LOG(INFO) << "Converting unsupported operation: " << node.op(); auto* op = new TensorFlowUnsupportedOperator; const int num_inputs = GetInputsCount(node, tf_import_flags); @@ -1072,15 +1108,16 @@ void ConvertUnsupportedOperator(const NodeDef& node, const auto& output_type = GetDataTypeAttr(node, "Tout"); op->output_data_types.push_back(ConvertDataType(output_type)); } + return tensorflow::Status::OK(); } -void ConvertStridedSliceOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertStridedSliceOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "StridedSlice"); // TODO(soroosh): The 4th input (strides) should be e optional, to be // consistent with TF. - CheckInputsCount(node, tf_import_flags, 4); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4)); auto* op = new StridedSliceOperator; for (const auto& input : node.input()) { @@ -1100,14 +1137,15 @@ void ConvertStridedSliceOperator(const NodeDef& node, : 0; model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertPlaceholderOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertPlaceholderOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput"); if (node.op() == "Placeholder") { - CheckInputsCount(node, tf_import_flags, 0); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 0)); } auto& array = model->GetOrCreateArray(node.name()); if (node.attr().count("dtype")) { @@ -1132,17 +1170,20 @@ void ConvertPlaceholderOperator(const NodeDef& node, } } } + return tensorflow::Status::OK(); } -void ConvertNoOpOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) {} +tensorflow::Status ConvertNoOpOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + return tensorflow::Status::OK(); +} -void ConvertCastOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertCastOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Cast"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT"); const auto tf_dst_dtype = GetDataTypeAttr(node, "DstT"); auto* op = new CastOperator; @@ -1151,27 +1192,31 @@ void ConvertCastOperator(const NodeDef& node, op->inputs.push_back(node.input(0)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertFloorOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertFloorOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Floor"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto data_type = GetDataTypeAttr(node, "T"); CHECK(data_type == DT_FLOAT); auto* op = new FloorOperator; op->inputs.push_back(node.input(0)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertGatherOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertGatherOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK(node.op() == "Gather" || node.op() == "GatherV2"); - if (node.op() == "Gather") CheckInputsCount(node, tf_import_flags, 2); - if (node.op() == "GatherV2") CheckInputsCount(node, tf_import_flags, 3); + if (node.op() == "Gather") + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); + if (node.op() == "GatherV2") + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); const auto indices_data_type = GetDataTypeAttr(node, "Tindices"); CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64); auto* op = new GatherOperator; @@ -1181,13 +1226,14 @@ void ConvertGatherOperator(const NodeDef& node, // should read it an pass it on to the TF Lite Interpreter. op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertArgMaxOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertArgMaxOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "ArgMax"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); const auto axis_data_type = HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32; const auto output_type = HasAttr(node, "output_type") @@ -1201,13 +1247,14 @@ void ConvertArgMaxOperator(const NodeDef& node, op->inputs.push_back(node.input(1)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertResizeBilinearOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertResizeBilinearOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "ResizeBilinear"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new ResizeBilinearOperator; op->align_corners = false; @@ -1219,13 +1266,14 @@ void ConvertResizeBilinearOperator(const NodeDef& node, op->inputs.push_back(node.input(1)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertBatchNormWithGlobalNormalizationOperator( +tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization"); - CheckInputsCount(node, tf_import_flags, 5); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5)); // TODO(ahentz): to really match tensorflow we need to add variance_epsilon // to the input, before feeding it into TensorFlowRsqrtOperator. @@ -1268,13 +1316,14 @@ void ConvertBatchNormWithGlobalNormalizationOperator( op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertFusedBatchNormOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertFusedBatchNormOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "FusedBatchNorm"); - CheckInputsCount(node, tf_import_flags, 5); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5)); // Declare shortcuts for the inputs. const string& gamma_input = node.input(1); @@ -1320,13 +1369,14 @@ void ConvertFusedBatchNormOperator(const NodeDef& node, op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertSpaceToBatchNDOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSpaceToBatchNDOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "SpaceToBatchND"); - CheckInputsCount(node, tf_import_flags, 3); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32); CHECK_EQ(GetDataTypeAttr(node, "Tpaddings"), DT_INT32); auto* op = new SpaceToBatchNDOperator; @@ -1335,13 +1385,14 @@ void ConvertSpaceToBatchNDOperator(const NodeDef& node, op->inputs.push_back(node.input(2)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertBatchToSpaceNDOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertBatchToSpaceNDOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "BatchToSpaceND"); - CheckInputsCount(node, tf_import_flags, 3); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32); CHECK_EQ(GetDataTypeAttr(node, "Tcrops"), DT_INT32); auto* op = new BatchToSpaceNDOperator; @@ -1350,13 +1401,14 @@ void ConvertBatchToSpaceNDOperator(const NodeDef& node, op->inputs.push_back(node.input(2)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertMeanOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertMeanOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Mean"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new MeanOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -1367,11 +1419,12 @@ void ConvertMeanOperator(const NodeDef& node, } else if (HasAttr(node, "keep_dims")) { op->keep_dims = GetBoolAttr(node, "keep_dims"); } + return tensorflow::Status::OK(); } -void ConvertSvdfOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSvdfOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Svdf"); const int input_size = GetInputsCount(node, tf_import_flags); QCHECK(input_size == 3 || input_size == 4) @@ -1394,14 +1447,15 @@ void ConvertSvdfOperator(const NodeDef& node, } op->rank = node.attr().at("Rank").i(); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } // This is just bare bones support to get the shapes to propagate. -void ConvertTransposeConvOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertTransposeConvOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Conv2DBackpropInput"); - CheckInputsCount(node, tf_import_flags, 3); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); auto* op = new TransposeConvOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -1465,14 +1519,14 @@ void ConvertTransposeConvOperator(const NodeDef& node, "Conv2DBackpropInput nodes."; } model->operators.emplace_back(op); + return tensorflow::Status::OK(); } - -void ConvertRangeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertRangeOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Range"); - CheckInputsCount(node, tf_import_flags, 3); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); auto* op = new RangeOperator; if (HasAttr(node, "Tidx")) { const auto dtype = toco::GetDataTypeAttr(node, "Tidx"); @@ -1485,11 +1539,12 @@ void ConvertRangeOperator(const NodeDef& node, op->inputs.push_back(node.input(2)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertStackOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertStackOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK((node.op() == "Stack") || (node.op() == "Pack")); auto* op = new StackOperator; const int num_inputs = GetInputsCount(node, tf_import_flags); @@ -1505,9 +1560,9 @@ void ConvertStackOperator(const NodeDef& node, op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0; op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } - // Some TensorFlow ops only occur in graph cycles, representing // control flow. We do not currently support control flow, so we wouldn't // be able to fully support such graphs, including performing inference, @@ -1518,7 +1573,7 @@ void ConvertStackOperator(const NodeDef& node, // such ops as RNN back-edges, which is technically incorrect (does not // allow representing the op's semantics) but good enough to get a // graph visualization. -void ConvertOperatorSpecialCasedAsRNNBackEdge( +tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { // At the moment, the only type of operator special-cased in this way is @@ -1531,6 +1586,23 @@ void ConvertOperatorSpecialCasedAsRNNBackEdge( rnn_state->set_discardable(true); rnn_state->set_state_array(node.name()); rnn_state->set_back_edge_source_array(node.input(0)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertShapeOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "Shape"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); + const auto out_type = + HasAttr(node, "out_type") ? GetDataTypeAttr(node, "out_type") : DT_INT32; + CHECK(out_type == DT_INT64 || out_type == DT_INT32); + auto op = absl::make_unique(); + op->output_data_type = ConvertDataType(out_type); + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + model->operators.push_back(std::move(op)); + return tensorflow::Status::OK(); } void StripCaretFromArrayNames(Model* model) { @@ -1673,9 +1745,9 @@ bool InlineAllFunctions(GraphDef* graphdef) { return graph_modified; } -void ConvertTopKV2Operator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertTopKV2Operator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK((node.op() == "TopK") || (node.op() == "TopKV2")); auto op = absl::make_unique(); op->inputs.push_back(node.input(0)); @@ -1685,22 +1757,23 @@ void ConvertTopKV2Operator(const NodeDef& node, model, node.name() + "k", {static_cast(GetIntAttr(node, "k"))}); op->inputs.push_back(k_array); } else { - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); op->inputs.push_back(node.input(1)); } // The op has two outputs. op->outputs.push_back(node.name()); op->outputs.push_back(node.name() + ":1"); model->operators.emplace_back(op.release()); + return tensorflow::Status::OK(); } -void ConvertDynamicPartitionOperator( +tensorflow::Status ConvertDynamicPartitionOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { auto op = absl::make_unique(); CHECK(HasAttr(node, "num_partitions")); op->num_partitions = GetIntAttr(node, "num_partitions"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); CHECK_GT(op->num_partitions, 1); @@ -1709,11 +1782,12 @@ void ConvertDynamicPartitionOperator( op->outputs.push_back(node.name() + ":" + std::to_string(i)); } model->operators.emplace_back(op.release()); + return tensorflow::Status::OK(); } -void ConvertDynamicStitchOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertDynamicStitchOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { // The parallel and non-parallel variants are the same besides whether they // have a parallel loop; there are no behavioral differences. CHECK(node.op() == "DynamicStitch" || node.op() == "ParallelDynamicStitch"); @@ -1721,19 +1795,20 @@ void ConvertDynamicStitchOperator(const NodeDef& node, CHECK(HasAttr(node, "N")); op->num_partitions = GetIntAttr(node, "N"); // Expect all ID partitions + all value partitions. - CheckInputsCount(node, tf_import_flags, op->num_partitions * 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, op->num_partitions * 2)); for (int i = 0; i < op->num_partitions * 2; ++i) { op->inputs.push_back(node.input(i)); } op->outputs.push_back(node.name()); model->operators.emplace_back(op.release()); + return tensorflow::Status::OK(); } -void ConvertSparseToDenseOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSparseToDenseOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "SparseToDense"); - CheckInputsCount(node, tf_import_flags, 4); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4)); auto* op = new SparseToDenseOperator; for (const string& input : node.input()) { @@ -1745,217 +1820,133 @@ void ConvertSparseToDenseOperator(const NodeDef& node, ? GetBoolAttr(node, "validate_indices") : true; model->operators.emplace_back(op); + return tensorflow::Status::OK(); } } // namespace namespace internal { + +using ConverterType = tensorflow::Status (*)( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model); +using ConverterMapType = std::unordered_map; + +ConverterMapType GetTensorFlowNodeConverterMap() { + return std::unordered_map({ + {"Add", ConvertSimpleOperator}, + {"AddN", ConvertSimpleOperator}, + {"All", ConvertSimpleOperator}, + {"ArgMax", ConvertArgMaxOperator}, + {"Assert", ConvertSimpleOperator}, + {"AvgPool", ConvertAvgPoolOperator}, + {"BatchMatMul", ConvertBatchMatMulOperator}, + {"BatchNormWithGlobalNormalization", + ConvertBatchNormWithGlobalNormalizationOperator}, + {"BatchToSpaceND", ConvertBatchToSpaceNDOperator}, + {"BiasAdd", ConvertBiasAddOperator}, + {"Cast", ConvertCastOperator}, + {"CheckNumerics", ConvertIdentityOperator}, + {"Concat", ConvertConcatOperator}, + {"ConcatV2", ConvertConcatOperator}, + {"Const", ConvertConstOperator}, + {"Conv2D", ConvertConvOperator}, + {"Conv2DBackpropInput", ConvertTransposeConvOperator}, + {"DepthToSpace", ConvertDepthToSpaceOperator}, + {"DepthwiseConv2dNative", ConvertDepthwiseConvOperator}, + {"Div", ConvertSimpleOperator}, + {"DynamicPartition", ConvertDynamicPartitionOperator}, + {"DynamicStitch", ConvertDynamicStitchOperator}, + {"Equal", ConvertSimpleOperator}, + {"Exp", ConvertSimpleOperator}, + {"ExpandDims", ConvertSimpleOperator}, + {"FakeQuantWithMinMaxArgs", ConvertFakeQuantWithMinMaxArgs}, + {"FakeQuantWithMinMaxVars", ConvertFakeQuantWithMinMaxVars}, + {"Fill", ConvertSimpleOperator}, + {"Floor", ConvertFloorOperator}, + {"FloorDiv", ConvertSimpleOperator}, + {"FloorMod", ConvertSimpleOperator}, + {"FusedBatchNorm", ConvertFusedBatchNormOperator}, + {"Gather", ConvertGatherOperator}, + {"GatherV2", ConvertGatherOperator}, + {"Greater", ConvertSimpleOperator}, + {"GreaterEqual", + ConvertSimpleOperator}, + {"Identity", ConvertIdentityOperator}, + {"LRN", ConvertLRNOperator}, + {"LegacyFedInput", ConvertPlaceholderOperator}, + {"Less", ConvertSimpleOperator}, + {"LessEqual", ConvertSimpleOperator}, + {"Log", ConvertSimpleOperator}, + {"Log", ConvertSimpleOperator}, + {"LogSoftmax", ConvertSimpleOperator}, + {"MatMul", ConvertMatMulOperator}, + {"Max", ConvertMaxOperator}, + {"MaxPool", ConvertMaxPoolOperator}, + {"Maximum", ConvertSimpleOperator}, + {"Mean", ConvertMeanOperator}, + {"Merge", ConvertSimpleOperator}, + {"Min", ConvertMinOperator}, + {"Minimum", ConvertSimpleOperator}, + {"Mul", ConvertSimpleOperator}, + {"Neg", ConvertSimpleOperator}, + {"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge}, + {"NoOp", ConvertNoOpOperator}, + {"NotEqual", ConvertSimpleOperator}, + {"Pack", ConvertStackOperator}, + {"Pad", ConvertSimpleOperator}, + {"PadV2", ConvertSimpleOperator}, + {"ParallelDynamicStitch", ConvertDynamicStitchOperator}, + {"Placeholder", ConvertPlaceholderOperator}, + {"PlaceholderWithDefault", ConvertIdentityOperator}, + {"Pow", ConvertSimpleOperator}, + {"RandomUniform", ConvertRandomUniform}, + {"Range", ConvertRangeOperator}, + {"Rank", ConvertSimpleOperator}, + {"RealDiv", ConvertSimpleOperator}, + {"Relu", ConvertSimpleOperator}, + {"Relu6", ConvertSimpleOperator}, + {"Reshape", ConvertSimpleOperator}, + {"ResizeBilinear", ConvertResizeBilinearOperator}, + {"Rsqrt", ConvertSimpleOperator}, + {"Select", ConvertSimpleOperator}, + {"Shape", ConvertShapeOperator}, + {"Sigmoid", ConvertSimpleOperator}, + {"Sin", ConvertSimpleOperator}, + {"Slice", ConvertSimpleOperator}, + {"Softmax", ConvertSoftmaxOperator}, + {"SpaceToBatchND", ConvertSpaceToBatchNDOperator}, + {"SpaceToDepth", ConvertSpaceToDepthOperator}, + {"SparseToDense", ConvertSparseToDenseOperator}, + {"Split", ConvertSplitOperator}, + {"Sqrt", ConvertSimpleOperator}, + {"Square", ConvertSimpleOperator}, + {"Squeeze", ConvertSqueezeOperator}, + {"Stack", ConvertStackOperator}, + {"StopGradient", ConvertIdentityOperator}, + {"StridedSlice", ConvertStridedSliceOperator}, + {"Sub", ConvertSimpleOperator}, + {"Sum", ConvertSumOperator}, + {"Svdf", ConvertSvdfOperator}, + {"Switch", ConvertSwitchOperator}, + {"Tanh", ConvertSimpleOperator}, + {"Tile", ConvertSimpleOperator}, + {"TopK", ConvertTopKV2Operator}, + {"TopKV2", ConvertTopKV2Operator}, + {"Transpose", ConvertSimpleOperator}, + }); +} + tensorflow::Status ImportTensorFlowNode( const tensorflow::NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, Model* model) { - // TODO(ahentz): Historically these functions all CHECK-fail on error. We've - // been slowly converting them to return Status. - if (node.op() == "Const") { - return ConvertConstOperator(node, tf_import_flags, model); - } else if (node.op() == "Conv2D") { - return ConvertConvOperator(node, tf_import_flags, model); - } else if (node.op() == "Conv2DBackpropInput") { - ConvertTransposeConvOperator(node, tf_import_flags, model); - } else if (node.op() == "DepthwiseConv2dNative") { - ConvertDepthwiseConvOperator(node, tf_import_flags, model); - } else if (node.op() == "DepthToSpace") { - ConvertDepthToSpaceOperator(node, tf_import_flags, model); - } else if (node.op() == "SpaceToDepth") { - ConvertSpaceToDepthOperator(node, tf_import_flags, model); - } else if (node.op() == "BiasAdd") { - ConvertBiasAddOperator(node, tf_import_flags, model); - } else if (node.op() == "Relu") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Relu6") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Sigmoid") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Tanh") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "MaxPool") { - ConvertMaxPoolOperator(node, tf_import_flags, model); - } else if (node.op() == "AvgPool") { - ConvertAvgPoolOperator(node, tf_import_flags, model); - } else if (node.op() == "Reshape") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "BatchMatMul") { - ConvertBatchMatMulOperator(node, tf_import_flags, model); - } else if (node.op() == "MatMul") { - ConvertMatMulOperator(node, tf_import_flags, model); - } else if (node.op() == "Div" || node.op() == "RealDiv") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Identity" || node.op() == "CheckNumerics" || - node.op() == "StopGradient") { - ConvertIdentityOperator(node, tf_import_flags, model); - } else if (node.op() == "FakeQuantWithMinMaxVars") { - ConvertFakeQuantWithMinMaxVars(node, tf_import_flags, model); - } else if (node.op() == "FakeQuantWithMinMaxArgs") { - ConvertFakeQuantWithMinMaxArgs(node, tf_import_flags, model); - } else if (node.op() == "Neg") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Rsqrt") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Squeeze") { - ConvertSqueezeOperator(node, tf_import_flags, model); - } else if (node.op() == "Sqrt") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Square") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Add") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "AddN") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Mul") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Sub") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Sum") { - ConvertSumOperator(node, tf_import_flags, model); - } else if (node.op() == "Tile") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Concat" || node.op() == "ConcatV2") { - ConvertConcatOperator(node, tf_import_flags, model); - } else if (node.op() == "LRN") { - ConvertLRNOperator(node, tf_import_flags, model); - } else if (node.op() == "Softmax") { - ConvertSoftmaxOperator(node, tf_import_flags, model); - } else if (node.op() == "Log") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "LogSoftmax") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "All") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Assert") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Less") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "LessEqual") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Greater") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "GreaterEqual") { - ConvertSimpleOperator( - node, tf_import_flags, model); - } else if (node.op() == "Max") { - ConvertMaxOperator(node, tf_import_flags, model); - } else if (node.op() == "Min") { - ConvertMinOperator(node, tf_import_flags, model); - } else if (node.op() == "Maximum") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Minimum") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Merge") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Pad") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "PadV2") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "StridedSlice") { - ConvertStridedSliceOperator(node, tf_import_flags, model); - } else if (node.op() == "Shape") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Slice") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Split") { - ConvertSplitOperator(node, tf_import_flags, model); - } else if (node.op() == "Switch") { - ConvertSwitchOperator(node, tf_import_flags, model); - } else if (node.op() == "Placeholder") { - ConvertPlaceholderOperator(node, tf_import_flags, model); - } else if (node.op() == "PlaceholderWithDefault") { - ConvertIdentityOperator(node, tf_import_flags, model); - } else if (node.op() == "LegacyFedInput") { - ConvertPlaceholderOperator(node, tf_import_flags, model); - } else if (node.op() == "NoOp") { - ConvertNoOpOperator(node, tf_import_flags, model); - } else if (node.op() == "Cast") { - ConvertCastOperator(node, tf_import_flags, model); - } else if (node.op() == "Floor") { - ConvertFloorOperator(node, tf_import_flags, model); - } else if (node.op() == "Gather" || node.op() == "GatherV2") { - ConvertGatherOperator(node, tf_import_flags, model); - } else if (node.op() == "ResizeBilinear") { - ConvertResizeBilinearOperator(node, tf_import_flags, model); - } else if (node.op() == "BatchNormWithGlobalNormalization") { - ConvertBatchNormWithGlobalNormalizationOperator(node, tf_import_flags, - model); - } else if (node.op() == "FusedBatchNorm") { - ConvertFusedBatchNormOperator(node, tf_import_flags, model); - } else if (node.op() == "SpaceToBatchND") { - ConvertSpaceToBatchNDOperator(node, tf_import_flags, model); - } else if (node.op() == "BatchToSpaceND") { - ConvertBatchToSpaceNDOperator(node, tf_import_flags, model); - } else if (node.op() == "Mean") { - ConvertMeanOperator(node, tf_import_flags, model); - } else if (node.op() == "Svdf") { - ConvertSvdfOperator(node, tf_import_flags, model); - } else if (node.op() == "NextIteration") { - ConvertOperatorSpecialCasedAsRNNBackEdge(node, tf_import_flags, model); - } else if (node.op() == "ExpandDims") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Fill") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "FloorDiv") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "FloorMod") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Range") { - ConvertRangeOperator(node, tf_import_flags, model); - } else if (node.op() == "Rank") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Stack" || node.op() == "Pack") { - ConvertStackOperator(node, tf_import_flags, model); - } else if (node.op() == "Transpose") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "ArgMax") { - ConvertArgMaxOperator(node, tf_import_flags, model); - } else if (node.op() == "Exp") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "TopK" || node.op() == "TopKV2") { - ConvertTopKV2Operator(node, tf_import_flags, model); - } else if (node.op() == "DynamicPartition") { - ConvertDynamicPartitionOperator(node, tf_import_flags, model); - } else if (node.op() == "DynamicStitch" || - node.op() == "ParallelDynamicStitch") { - ConvertDynamicStitchOperator(node, tf_import_flags, model); - } else if (node.op() == "RandomUniform") { - ConvertRandomUniform(node, tf_import_flags, model); - } else if (node.op() == "Sin") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Log") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Select") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "SparseToDense") { - ConvertSparseToDenseOperator(node, tf_import_flags, model); - } else if (node.op() == "Equal") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "NotEqual") { - ConvertSimpleOperator(node, tf_import_flags, - model); + const TensorFlowImportFlags& tf_import_flags, Model* model, + const ConverterMapType& converter_map) { + auto converter = converter_map.find(node.op()); + if (converter == converter_map.end()) { + return ConvertUnsupportedOperator(node, tf_import_flags, model); } else { - ConvertUnsupportedOperator(node, tf_import_flags, model); + return converter->second(node, tf_import_flags, model); } - return tensorflow::Status::OK(); } } // namespace internal @@ -1981,10 +1972,13 @@ std::unique_ptr ImportTensorFlowGraphDef( } Model* model = new Model; + const internal::ConverterMapType& converter_map = + internal::GetTensorFlowNodeConverterMap(); for (auto node : inlined_graph.node()) { StripZeroOutputIndexFromInputs(&node); - auto status = internal::ImportTensorFlowNode(node, tf_import_flags, model); + auto status = internal::ImportTensorFlowNode(node, tf_import_flags, model, + converter_map); CHECK(status.ok()) << status.error_message(); } diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc index d18c329a43411236f8fd5446998c168803b9373a..90e6f698efee6a6a32da18a658e72c3e8b6550c0 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc @@ -36,8 +36,14 @@ using tensorflow::NodeDef; using tensorflow::Status; namespace internal { +using ConverterType = tensorflow::Status (*)( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model); +using ConverterMapType = std::unordered_map; + +ConverterMapType GetTensorFlowNodeConverterMap(); Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&, - Model*); + Model*, const ConverterMapType&); } // namespace internal namespace { @@ -105,8 +111,9 @@ class ShapeImportTest : public ::testing::TestWithParam { Status ImportNode(const NodeDef& node) { Model model; - return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), - &model); + const auto converter = internal::GetTensorFlowNodeConverterMap(); + return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), &model, + converter); } }; diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 7bdec47aa9c1a960d0324c5f6a4b19f69cd056b2..abe0bf3c54460709dc67a4d5835df77ca8a83575 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ #define TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ +#include #include #include #include @@ -32,7 +33,7 @@ namespace toco { using tflite::QuantizationParams; -enum class OperatorType { +enum class OperatorType : uint8 { kNone, // General-purpose neural network operators. kAdd, @@ -96,38 +97,38 @@ enum class OperatorType { // Special operators used for importing TensorFlow nodes. // The general intent is to have some graph transformation either // drop them or rewrite them as general-purpose operators. - kTensorFlowAll, - kTensorFlowAssert, - kTensorFlowConcat, - kTensorFlowConcatV2, - kTensorFlowGreater, - kTensorFlowGreaterEqual, - kTensorFlowIdentity, - kTensorFlowLess, - kTensorFlowLessEqual, - kTensorFlowMax, - kTensorFlowMaximum, - kTensorFlowMin, - kTensorFlowMinimum, - kTensorFlowMatMul, - kTensorFlowMerge, + kAll, + kAssert, + kConcat, + kConcatV2, + kGreater, + kGreaterEqual, + kIdentity, + kLess, + kLessEqual, + kMax, // Reduction Max + kMaximum, // Element-wise Maximum + kMin, // Reduction Min + kMinimum, // Element-wise Minimum + kMatMul, + kMerge, kNeg, - kTensorFlowReshape, - kTensorFlowRsqrt, - kTensorFlowShape, - kTensorFlowSplit, - kTensorFlowSqrt, - kTensorFlowSquare, - kTensorFlowSum, - kTensorFlowSwitch, - kTensorFlowTile, + kReshape, + kRsqrt, + kShape, + kSplit, + kSqrt, + kSquare, + kSum, + kSwitch, + kTile, kTranspose, kTopK_V2, kDynamicPartition, kDynamicStitch, // An unsupported TF operation. It's only needed to be able to represent TF // graph internally and is expected to be dropped by graph transformations. - kTensorFlowUnsupported, + kUnsupported, // Finally, TensorFlow uses different conventions for axes ordering, // see AxesOrder, and this cannot always be resolved at the time of importing // nodes, as TensorFlow parameters may be constant-expression subgraphs @@ -136,8 +137,9 @@ enum class OperatorType { kReorderAxes, kSelect, kSparseToDense, - kTensorFlowEqual, - kTensorFlowNotEqual, + kEqual, + kNotEqual, + kPow, }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -160,21 +162,22 @@ enum class AxesOrder { // The type of the scalars in an array. // Note that the type does not by itself tell whether the values in the array -// are real (are literally interpreted as real numbers) or quantized (only -// acquire a meaning as real numbers in conjunction with QuantizationParams). +// are non-quantized (can be accessed directly) or quantized (must be +// interpreted in conjunction with QuantizationParams). // // In practice though: -// float values are always real +// float values are never quantized // uint8 values are always quantized -// int32 values are either real or quantized (depending on whether +// int32 values are sometimes quantized (depending on whether // QuantizationParams are present). -// other types are unused at the moment. +// complex values are never quantized +// other types are never quantized at the moment. // // kNone means that we don't know the data type yet, or that we don't care // because we'll be dropping the array anyway (e.g. some exotic array types // may be involved only in debug-only subgraphs that we may not be interested // in actually supporting). -enum class ArrayDataType { +enum class ArrayDataType : uint8 { kNone, // 0 kBool, kFloat, @@ -186,7 +189,8 @@ enum class ArrayDataType { kUint32, kInt64, kUint64, // 10 - kString + kString, + kComplex64, }; // Compile-time logic to map ArrayDataType to the corresponding C++ scalar type @@ -240,6 +244,10 @@ template <> struct DataTypeImpl { typedef string Type; }; +template <> +struct DataTypeImpl { + typedef std::complex Type; +}; template using DataType = typename DataTypeImpl::Type; @@ -433,7 +441,8 @@ struct SpaceToDepthOperator : Operator { // input activations as a matrix, followed by a MatMul node. struct FullyConnectedOperator : Operator { FullyConnectedOperator() : Operator(OperatorType::kFullyConnected) {} - bool experimental_shuffled_weights = false; + FullyConnectedWeightsFormat weights_format = + FullyConnectedWeightsFormat::kDefault; }; // Dequantization operator, converting a quantized array of integers with @@ -801,7 +810,7 @@ struct DivOperator : Operator { // // TensorFlow equivalent: Identity struct TensorFlowIdentityOperator : Operator { - TensorFlowIdentityOperator() : Operator(OperatorType::kTensorFlowIdentity) {} + TensorFlowIdentityOperator() : Operator(OperatorType::kIdentity) {} }; // Batch matrix multiplication operator. This comes from the (deprecated) @@ -827,7 +836,7 @@ struct BatchMatMulOperator : Operator { // // TensorFlow equivalent: MatMul struct TensorFlowMatMulOperator : Operator { - TensorFlowMatMulOperator() : Operator(OperatorType::kTensorFlowMatMul) {} + TensorFlowMatMulOperator() : Operator(OperatorType::kMatMul) {} }; // Padding operator. Pads a tensor with zeros. @@ -961,7 +970,7 @@ struct StridedSliceOperator : Operator { // TensorFlow equivalent: Reshape --- except that we only support a special case // here, where the output shape is a matrix (2D) shape. struct TensorFlowReshapeOperator : Operator { - TensorFlowReshapeOperator() : Operator(OperatorType::kTensorFlowReshape) {} + TensorFlowReshapeOperator() : Operator(OperatorType::kReshape) {} std::vector shape; }; @@ -1131,7 +1140,7 @@ struct SelectOperator : Operator { // // TensorFlow equivalent: Rsqrt struct TensorFlowRsqrtOperator : Operator { - TensorFlowRsqrtOperator() : Operator(OperatorType::kTensorFlowRsqrt) {} + TensorFlowRsqrtOperator() : Operator(OperatorType::kRsqrt) {} }; // Stacks a list of rank-R tensors into one rank-(R+1) tensor. @@ -1157,10 +1166,10 @@ struct StackOperator : Operator { // This operation outputs a 1-D integer tensor representing the shape of // the input. // -// TensorFlow equivalent: Shape. We currently assume that the output is int32 -// and not int64. The output type could be stored herein. +// TensorFlow equivalent: Shape. struct TensorFlowShapeOperator : Operator { - TensorFlowShapeOperator() : Operator(OperatorType::kTensorFlowShape) {} + TensorFlowShapeOperator() : Operator(OperatorType::kShape) {} + ArrayDataType output_data_type = ArrayDataType::kInt32; }; // Element-wise square-root (x^0.5) operator. @@ -1170,7 +1179,7 @@ struct TensorFlowShapeOperator : Operator { // // TensorFlow equivalent: Sqrt struct TensorFlowSqrtOperator : Operator { - TensorFlowSqrtOperator() : Operator(OperatorType::kTensorFlowSqrt) {} + TensorFlowSqrtOperator() : Operator(OperatorType::kSqrt) {} }; // Element-wise square (x*x) operator. @@ -1180,7 +1189,7 @@ struct TensorFlowSqrtOperator : Operator { // // TensorFlow equivalent: Square struct TensorFlowSquareOperator : Operator { - TensorFlowSquareOperator() : Operator(OperatorType::kTensorFlowSquare) {} + TensorFlowSquareOperator() : Operator(OperatorType::kSquare) {} }; // Transposes a tensor. @@ -1208,16 +1217,14 @@ struct SubOperator : Operator { SubOperator() : Operator(OperatorType::kSub) {} }; -// Global sum reduction: computes the sum of all of entries in the input array. -// Thus the output is "0-dimensional": it consists of a single scalar value. +// Sum reduction: computes the sum of all of entries across the axes. // // Inputs: // inputs[0]: required: the input array // -// TensorFlow equivalent: Sum --- except that we only support the special case -// of global reduction across all dimensions. +// TensorFlow equivalent: Sum struct TensorFlowSumOperator : Operator { - TensorFlowSumOperator() : Operator(OperatorType::kTensorFlowSum) {} + TensorFlowSumOperator() : Operator(OperatorType::kSum) {} bool keep_dims = false; }; @@ -1227,7 +1234,7 @@ struct TensorFlowSumOperator : Operator { // inputs[0]: required: the input array // inputs[1]: required: int array with length of rank(input[0]) struct TensorFlowTileOperator : Operator { - TensorFlowTileOperator() : Operator(OperatorType::kTensorFlowTile) {} + TensorFlowTileOperator() : Operator(OperatorType::kTile) {} }; // TensorFlow Slice equivalent. Refer to TensorFlow documentation for details. @@ -1242,7 +1249,7 @@ struct SliceOperator : Operator { // Not fully supported, just a placeholder to handle TensorFlow graphs and // support graph transformations to other operator types by matching sub-graphs. struct TensorFlowSplitOperator : Operator { - TensorFlowSplitOperator() : Operator(OperatorType::kTensorFlowSplit) {} + TensorFlowSplitOperator() : Operator(OperatorType::kSplit) {} int num_split = 0; }; @@ -1253,7 +1260,7 @@ struct TensorFlowSplitOperator : Operator { // dimension then we can change this op into a DepthConcatenation op. // Otherwise, we hope for some other graph transformation to drop this node. struct TensorFlowConcatOperator : Operator { - TensorFlowConcatOperator() : Operator(OperatorType::kTensorFlowConcat) {} + TensorFlowConcatOperator() : Operator(OperatorType::kConcat) {} }; // TensorFlow ConcatV2 equivalent. Refer to TensorFlow documentation for @@ -1264,7 +1271,7 @@ struct TensorFlowConcatOperator : Operator { // dimension then we can change this op into a DepthConcatenation op. // Otherwise, we hope for some other graph transformation to drop this node. struct TensorFlowConcatV2Operator : Operator { - TensorFlowConcatV2Operator() : Operator(OperatorType::kTensorFlowConcatV2) {} + TensorFlowConcatV2Operator() : Operator(OperatorType::kConcatV2) {} }; // TensorFlow Merge equivalent. Refer to TensorFlow documentation for details. @@ -1280,7 +1287,7 @@ struct TensorFlowConcatV2Operator : Operator { // control flow that can be resolved at tooling time (independently of input // activations). struct TensorFlowMergeOperator : Operator { - TensorFlowMergeOperator() : Operator(OperatorType::kTensorFlowMerge) {} + TensorFlowMergeOperator() : Operator(OperatorType::kMerge) {} }; // TensorFlow Switch equivalent. Refer to TensorFlow documentation for details. @@ -1303,7 +1310,7 @@ struct TensorFlowMergeOperator : Operator { // control flow that can be resolved at tooling time (independently of input // activations). struct TensorFlowSwitchOperator : Operator { - TensorFlowSwitchOperator() : Operator(OperatorType::kTensorFlowSwitch) {} + TensorFlowSwitchOperator() : Operator(OperatorType::kSwitch) {} }; // TensorFlow All equivalent. Refer to TensorFlow documentation for details. @@ -1312,7 +1319,7 @@ struct TensorFlowSwitchOperator : Operator { // Typically, this is only used as an input to an Assert node, so can be // removed as an unused node as we drop Assert nodes. struct TensorFlowAllOperator : Operator { - TensorFlowAllOperator() : Operator(OperatorType::kTensorFlowAll) {} + TensorFlowAllOperator() : Operator(OperatorType::kAll) {} }; // TensorFlow Assert equivalent. Refer to TensorFlow documentation for details. @@ -1320,7 +1327,7 @@ struct TensorFlowAllOperator : Operator { // support graph transformations to other operator types by matching sub-graphs. // Typically, we just drop Assert nodes. struct TensorFlowAssertOperator : Operator { - TensorFlowAssertOperator() : Operator(OperatorType::kTensorFlowAssert) {} + TensorFlowAssertOperator() : Operator(OperatorType::kAssert) {} }; // TensorFlow Less equivalent. Refer to TensorFlow documentation for details. @@ -1329,7 +1336,7 @@ struct TensorFlowAssertOperator : Operator { // Typically, this is only used as an input to an Assert node, so can be // removed as an unused node as we drop Assert nodes. struct TensorFlowLessOperator : Operator { - TensorFlowLessOperator() : Operator(OperatorType::kTensorFlowLess) {} + TensorFlowLessOperator() : Operator(OperatorType::kLess) {} }; // TensorFlow LessEqual equivalent. Refer to TensorFlow documentation for @@ -1339,8 +1346,7 @@ struct TensorFlowLessOperator : Operator { // Typically, this is only used as an input to an Assert node, so can be // removed as an unused node as we drop Assert nodes. struct TensorFlowLessEqualOperator : Operator { - TensorFlowLessEqualOperator() - : Operator(OperatorType::kTensorFlowLessEqual) {} + TensorFlowLessEqualOperator() : Operator(OperatorType::kLessEqual) {} }; // TensorFlow Less equivalent. Refer to TensorFlow documentation for details. @@ -1349,7 +1355,7 @@ struct TensorFlowLessEqualOperator : Operator { // Typically, this is only used as an input to an Assert node, so can be // removed as an unused node as we drop Assert nodes. struct TensorFlowGreaterOperator : Operator { - TensorFlowGreaterOperator() : Operator(OperatorType::kTensorFlowGreater) {} + TensorFlowGreaterOperator() : Operator(OperatorType::kGreater) {} }; // TensorFlow GreaterEqual equivalent. Refer to TensorFlow documentation for @@ -1359,8 +1365,7 @@ struct TensorFlowGreaterOperator : Operator { // Typically, this is only used as an input to an Assert node, so can be // removed as an unused node as we drop Assert nodes. struct TensorFlowGreaterEqualOperator : Operator { - TensorFlowGreaterEqualOperator() - : Operator(OperatorType::kTensorFlowGreaterEqual) {} + TensorFlowGreaterEqualOperator() : Operator(OperatorType::kGreaterEqual) {} }; // TensorFlow Equal equivalent. Refer to TensorFlow documentation for @@ -1370,13 +1375,13 @@ struct TensorFlowGreaterEqualOperator : Operator { // Typically, this is only used as an input to an Assert node, so can be // removed as an unused node as we drop Assert nodes. struct TensorFlowEqualOperator : Operator { - TensorFlowEqualOperator() : Operator(OperatorType::kTensorFlowEqual) {} + TensorFlowEqualOperator() : Operator(OperatorType::kEqual) {} }; // TensorFlow Not Equal equivalent. Refer to TensorFlow documentation for // details. struct TensorFlowNotEqualOperator : Operator { - TensorFlowNotEqualOperator() : Operator(OperatorType::kTensorFlowNotEqual) {} + TensorFlowNotEqualOperator() : Operator(OperatorType::kNotEqual) {} }; // Global max reduction: computes the max of all of entries in the input array. @@ -1388,7 +1393,7 @@ struct TensorFlowNotEqualOperator : Operator { // TensorFlow equivalent: Max --- except that we only support the special case // of global reduction across all dimensions. struct TensorFlowMaxOperator : Operator { - TensorFlowMaxOperator() : Operator(OperatorType::kTensorFlowMax) {} + TensorFlowMaxOperator() : Operator(OperatorType::kMax) {} bool keep_dims = false; }; @@ -1401,7 +1406,7 @@ struct TensorFlowMaxOperator : Operator { // TensorFlow equivalent: Min --- except that we only support the special case // of global reduction across all dimensions. struct TensorFlowMinOperator : Operator { - TensorFlowMinOperator() : Operator(OperatorType::kTensorFlowMin) {} + TensorFlowMinOperator() : Operator(OperatorType::kMin) {} bool keep_dims = false; }; @@ -1414,7 +1419,7 @@ struct TensorFlowMinOperator : Operator { // // TensorFlow equivalent: Maximum struct TensorFlowMaximumOperator : Operator { - TensorFlowMaximumOperator() : Operator(OperatorType::kTensorFlowMaximum) {} + TensorFlowMaximumOperator() : Operator(OperatorType::kMaximum) {} }; // Element-wise minimum operator. Currently it only supports scalar as @@ -1426,14 +1431,13 @@ struct TensorFlowMaximumOperator : Operator { // // TensorFlow equivalent: Minimum struct TensorFlowMinimumOperator : Operator { - TensorFlowMinimumOperator() : Operator(OperatorType::kTensorFlowMinimum) {} + TensorFlowMinimumOperator() : Operator(OperatorType::kMinimum) {} }; // General TF operation, unsupported by tf.mini. Expected to be dropped by // graph transformations. struct TensorFlowUnsupportedOperator : Operator { - TensorFlowUnsupportedOperator() - : Operator(OperatorType::kTensorFlowUnsupported) {} + TensorFlowUnsupportedOperator() : Operator(OperatorType::kUnsupported) {} // The original TF operation type. Used for diagnostic purposes. string tensorflow_op; @@ -1641,6 +1645,17 @@ struct SparseToDenseOperator : Operator { bool validate_indices; }; +// Pow operator: +// +// Inputs: +// Inputs[0]: required: A tensor. +// Inputs[1]: required: A tensor. +// +// TensorFlow equivalent: Pow. +struct PowOperator : Operator { + PowOperator() : Operator(OperatorType::kPow) {} +}; + // Alloc's are used for transient arrays only. An Alloc specifies which interval // of the "transient_data" workspace buffer passed to inference functions, is to // be used for the transient array at hand. The 'start' and 'end' values are diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc index 4c9f1aa4b0274b5123bb3baa9b9fca1463bda4c3..06072d1fcb0612ed8193b3a0be1317923fe95bcc 100644 --- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc @@ -74,10 +74,10 @@ bool ParseModelFlagsFromCommandLineFlags( "height, input array width, input array depth."), Flag("batch_size", parsed_flags.batch_size.bind(), parsed_flags.batch_size.default_value(), - "Batch size for the model. Replaces the first dimension of an " - "input size array if undefined. Use only with SavedModels when " - "--input_shapes flag is not specified. Always use --input_shapes " - "flag with frozen graphs."), + "Deprecated. Batch size for the model. Replaces the first dimension " + "of an input size array if undefined. Use only with SavedModels " + "when --input_shapes flag is not specified. Always use " + "--input_shapes flag with frozen graphs."), Flag("input_data_type", parsed_flags.input_data_type.bind(), parsed_flags.input_data_type.default_value(), "Deprecated: use --input_data_types instead. Input array type, if " diff --git a/tensorflow/contrib/lite/toco/runtime/types.h b/tensorflow/contrib/lite/toco/runtime/types.h index f5de5a5781a5304634642680e6a3cef60e7b844b..207f2c1706ef4cc12572e381c38f61a504ece232 100644 --- a/tensorflow/contrib/lite/toco/runtime/types.h +++ b/tensorflow/contrib/lite/toco/runtime/types.h @@ -24,6 +24,7 @@ namespace toco { // TODO(ahentz): These are just stopgaps for now, untils we move all // the code over to tflite. using tflite::Dims; +using tflite::FullyConnectedWeightsFormat; using tflite::FusedActivationFunctionType; using tflite::RequiredBufferSizeForDims; diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD index e1025c66642d2860c5916bf7625f1c0403c9901c..a02f90988b2863900b6a735fd69aa1975a762338 100644 --- a/tensorflow/contrib/lite/toco/tflite/BUILD +++ b/tensorflow/contrib/lite/toco/tflite/BUILD @@ -24,6 +24,7 @@ cc_library( deps = [ ":types", "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/contrib/lite/toco:graph_transformations", "//tensorflow/contrib/lite/toco:model", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/memory", diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc index 7ba2603a952f6611e987901b735e9d4212f014ea..19722468079a32b76f6952db6ca818da470a03ac 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.cc +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -49,7 +49,7 @@ details::OperatorKey GetOperatorKey( const ::toco::Operator& op, const std::map>& ops_by_type) { string custom_code; - if (op.type == OperatorType::kTensorFlowUnsupported) { + if (op.type == OperatorType::kUnsupported) { const TensorFlowUnsupportedOperator& unsupported_op = static_cast(op); custom_code = unsupported_op.tensorflow_op; @@ -211,7 +211,7 @@ Offset>> ExportOperatorCodes( ordered_opcodes[op_index] = CreateOperatorCode(*builder, builtin_ops[name], 0, op_version); } else { - // This could be a kTensorFlowUnsupported, in which case we should be + // This could be a kUnsupported, in which case we should be // able to retrieve the original Tensorflow name from the OperatorKey, or // this could be a proper TOCO operator that is completely unknown to TF // Lite. @@ -268,7 +268,7 @@ Offset>> ExportOperators( : tflite_op_it->second.get(); // This is a custom op unless we can find it in ops_by_type, and even then - // it could be a custom op (such as kTensorFlowUnsupported). + // it could be a custom op (such as kUnsupported). auto options = Options::Custom(0); std::vector mutating_input_variables; diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h index 098d2163e6c2fe26f3cb9cdf9959df62a1a4baf0..58ea5c725c378827aac79f2a5a2cdca59ccc0162 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.h +++ b/tensorflow/contrib/lite/toco/tflite/export.h @@ -45,7 +45,7 @@ namespace details { using TensorsMap = std::unordered_map; // A key to identify an operator. -// Only when `type` is `kTensorFlowUnsupported`, `custom_code` is filled to +// Only when `type` is `kUnsupported`, `custom_code` is filled to // identify which operation is used. struct OperatorKey { OperatorKey(OperatorType type, const std::string& custom_code, int version) diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc index 409e7d72a57076ec2832c5d12b52829477624f74..d1fdbcb8e9131e1d65fa32ca0395bbc17b2014e7 100644 --- a/tensorflow/contrib/lite/toco/tflite/export_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc @@ -73,8 +73,8 @@ TEST_F(ExportTest, LoadOperatorsMap) { EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]); EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]); EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "", 1)]); - EXPECT_EQ(3, operators[details::OperatorKey( - OperatorType::kTensorFlowUnsupported, "MyCrazyOp", 1)]); + EXPECT_EQ(3, operators[details::OperatorKey(OperatorType::kUnsupported, + "MyCrazyOp", 1)]); } TEST_F(ExportTest, Export) { diff --git a/tensorflow/contrib/lite/toco/tflite/import.cc b/tensorflow/contrib/lite/toco/tflite/import.cc index cb44a5e6d7356a1cf5597bbe48565c5b1e1949a6..1dd4915b31413e5afb04b45ee7c4893a2eded66d 100644 --- a/tensorflow/contrib/lite/toco/tflite/import.cc +++ b/tensorflow/contrib/lite/toco/tflite/import.cc @@ -124,7 +124,7 @@ void ImportOperators( new_op = ops_by_name.at(effective_opname) ->Deserialize(input_op->builtin_options(), input_op->custom_options()); - if (new_op->type == OperatorType::kTensorFlowUnsupported) { + if (new_op->type == OperatorType::kUnsupported) { auto* unsupported_op = static_cast(new_op.get()); unsupported_op->tensorflow_op = opname; @@ -221,6 +221,8 @@ std::unique_ptr Import(const ModelFlags& model_flags, model.get()); ImportIOTensors(*input_model, tensors_table, model.get()); + UndoWeightsShuffling(model.get()); + return model; } diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index a0fbb58acafbea72a0678754d1a6ae4275580e44..7e55ae92bd57447cc821b21b40ba289cb484a9ed 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/toco/tflite/operator.h" +// TODO(ycling): Consider refactoring to extract the LSTM definition out of +// graph_transformation module. +#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h" #include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h" #include "tensorflow/contrib/lite/toco/tflite/custom_operator.h" #include "tensorflow/contrib/lite/toco/tflite/simple_operator.h" @@ -311,16 +314,47 @@ class FullyConnected flatbuffers::FlatBufferBuilder* builder) const override { auto activation_function = ActivationFunction::Serialize(op.fused_activation_function); - return ::tflite::CreateFullyConnectedOptions(*builder, activation_function); + ::tflite::FullyConnectedOptionsWeightsFormat tflite_weights_format; + switch (op.weights_format) { + case FullyConnectedWeightsFormat::kDefault: + tflite_weights_format = + ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT; + break; + case FullyConnectedWeightsFormat::kShuffled4x16Int8: + tflite_weights_format = + ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8; + break; + default: + LOG(ERROR) << "Unhandled FC weights format"; + tflite_weights_format = + ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT; + } + return ::tflite::CreateFullyConnectedOptions(*builder, activation_function, + tflite_weights_format); } void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override { op->fused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); + switch (options.weights_format()) { + case ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT: + op->weights_format = FullyConnectedWeightsFormat::kDefault; + break; + case ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8: + op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8; + break; + default: + LOG(ERROR) << "Unhandled FC weights format"; + op->weights_format = FullyConnectedWeightsFormat::kDefault; + } } - int GetVersion(const Operator& op) const override { return 1; } + int GetVersion(const Operator& op) const override { + const auto& fc_op = static_cast(op); + return fc_op.weights_format == FullyConnectedWeightsFormat::kDefault ? 1 + : 2; + } }; class Gather : public BuiltinOperator(op); + std::vector mutating_input_variables(op.inputs.size(), false); switch (lstm_op.kernel_type) { - case LstmCellOperator::KERNEL_FULL: - // TODO(ycling): Change the full kernel to use the new variable tensor - // design. This requires moving the state tensors from output to input. - return std::vector(); + case LstmCellOperator::KERNEL_FULL: { + mutating_input_variables[kInputActivationStateTensor] = true; + mutating_input_variables[kInputCellStateTensor] = true; + break; + } case LstmCellOperator::KERNEL_BASIC: { - std::vector mutating_input_variables(op.inputs.size(), false); mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true; mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true; - return mutating_input_variables; + break; } } + return mutating_input_variables; + } +}; + +class Mean : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateReducerOptions(*builder, op.keep_dims); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->keep_dims = options.keep_dims(); } + + int GetVersion(const Operator& op) const override { return 1; } }; -class Mean : public BuiltinOperator { +class Sum + : public BuiltinOperator { public: using BuiltinOperator::BuiltinOperator; flatbuffers::Offset WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { - return ::tflite::CreateMeanOptions(*builder, op.keep_dims); + return ::tflite::CreateReducerOptions(*builder, op.keep_dims); } void ReadOptions(const TfLiteOptions& options, @@ -894,6 +949,26 @@ class ExpandDims int GetVersion(const Operator& op) const override { return 1; } }; +class Shape + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateShapeOptions( + *builder, DataType::Serialize(op.output_data_type)); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->output_data_type = DataType::Deserialize(options.out_type()); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + class TensorFlowUnsupported : public BaseOperator { public: using BaseOperator::BaseOperator; @@ -954,6 +1029,20 @@ class TensorFlowUnsupported : public BaseOperator { fbb->Bool(key, attr.b()); has_valid_attr = true; break; + case tensorflow::AttrValue::kList: + if (attr.list().i_size() > 0) { + auto start = fbb->StartVector(key); + for (const int64_t v : attr.list().i()) { + fbb->Add(v); + } + fbb->EndVector(start, /*typed=*/true, /*fixed=*/false); + has_valid_attr = true; + } else { + LOG(WARNING) + << "Ignoring unsupported type in list attribute with key '" + << key << "'"; + } + break; default: LOG(WARNING) << "Ignoring unsupported attribute type with key '" << key << "'"; @@ -990,6 +1079,14 @@ class TensorFlowUnsupported : public BaseOperator { case flexbuffers::TYPE_BOOL: (*attr)[key].set_b(value.AsBool()); break; + case flexbuffers::TYPE_VECTOR_INT: { + auto* list = (*attr)[key].mutable_list(); + const auto& vector = value.AsTypedVector(); + for (size_t i = 0; i < vector.size(); i++) { + list->add_i(vector[i].AsInt64()); + } + break; + } default: LOG(WARNING) << "Ignoring unsupported attribute type with key '" << key << "'"; @@ -1048,8 +1145,8 @@ std::vector> BuildOperatorList() { ops.emplace_back(new Pad(::tflite::BuiltinOperator_PAD, OperatorType::kPad)); ops.emplace_back( new PadV2(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2)); - ops.emplace_back(new Reshape(::tflite::BuiltinOperator_RESHAPE, - OperatorType::kTensorFlowReshape)); + ops.emplace_back( + new Reshape(::tflite::BuiltinOperator_RESHAPE, OperatorType::kReshape)); ops.emplace_back( new Softmax(::tflite::BuiltinOperator_SOFTMAX, OperatorType::kSoftmax)); ops.emplace_back(new SpaceToDepth(::tflite::BuiltinOperator_SPACE_TO_DEPTH, @@ -1060,12 +1157,13 @@ std::vector> BuildOperatorList() { OperatorType::kTranspose)); ops.emplace_back( new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean)); + ops.emplace_back(new Sum(::tflite::BuiltinOperator_SUM, OperatorType::kSum)); ops.emplace_back(new ResizeBilinear(::tflite::BuiltinOperator_RESIZE_BILINEAR, OperatorType::kResizeBilinear)); ops.emplace_back( new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze)); - ops.emplace_back(new Split(::tflite::BuiltinOperator_SPLIT, - OperatorType::kTensorFlowSplit)); + ops.emplace_back( + new Split(::tflite::BuiltinOperator_SPLIT, OperatorType::kSplit)); ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE, OperatorType::kStridedSlice)); ops.emplace_back( @@ -1077,27 +1175,27 @@ std::vector> BuildOperatorList() { ops.emplace_back( new ArgMax(::tflite::BuiltinOperator_ARG_MAX, OperatorType::kArgMax)); ops.emplace_back( - new Tile(::tflite::BuiltinOperator_TILE, OperatorType::kTensorFlowTile)); + new Tile(::tflite::BuiltinOperator_TILE, OperatorType::kTile)); ops.emplace_back(new ExpandDims(::tflite::BuiltinOperator_EXPAND_DIMS, OperatorType::kExpandDims)); ops.emplace_back(new TransposeConv(::tflite::BuiltinOperator_TRANSPOSE_CONV, OperatorType::kTransposeConv)); ops.emplace_back(new SparseToDense(::tflite::BuiltinOperator_SPARSE_TO_DENSE, OperatorType::kSparseToDense)); + ops.emplace_back( + new Shape(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape)); // Custom Operators. ops.emplace_back( new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace)); ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant)); - ops.emplace_back(new TensorFlowUnsupported( - "TENSORFLOW_UNSUPPORTED", OperatorType::kTensorFlowUnsupported)); + ops.emplace_back(new TensorFlowUnsupported("TENSORFLOW_UNSUPPORTED", + OperatorType::kUnsupported)); // There operators are supported by Toco, but not by TF Lite, and has no // attributes. ops.emplace_back( new SimpleOperator("ADDN", OperatorType::kAddN)); - ops.emplace_back(new SimpleOperator( - "RSQRT", OperatorType::kTensorFlowRsqrt)); // Simple Operators. ops.emplace_back(new SimpleOperator( "DEQUANTIZE", OperatorType::kDequantize)); @@ -1119,29 +1217,34 @@ std::vector> BuildOperatorList() { ops.emplace_back(new SimpleOperator( "LOG_SOFTMAX", OperatorType::kLogSoftmax)); ops.emplace_back(new SimpleOperator( - "MAXIMUM", OperatorType::kTensorFlowMaximum)); + "MAXIMUM", OperatorType::kMaximum)); // Element-wise Maximum ops.emplace_back(new SimpleOperator( - "MINIMUM", OperatorType::kTensorFlowMinimum)); + "MINIMUM", OperatorType::kMinimum)); // Element-wise Minimum ops.emplace_back(new SimpleOperator( - "GREATER", OperatorType::kTensorFlowGreater)); + "GREATER", OperatorType::kGreater)); ops.emplace_back(new SimpleOperator( - "GREATER_EQUAL", OperatorType::kTensorFlowGreaterEqual)); - ops.emplace_back(new SimpleOperator( - "LESS", OperatorType::kTensorFlowLess)); + "GREATER_EQUAL", OperatorType::kGreaterEqual)); + ops.emplace_back( + new SimpleOperator("LESS", OperatorType::kLess)); ops.emplace_back(new SimpleOperator( - "LESS_EQUAL", OperatorType::kTensorFlowLessEqual)); + "LESS_EQUAL", OperatorType::kLessEqual)); ops.emplace_back(new SimpleOperator( - "EQUAL", OperatorType::kTensorFlowEqual)); + "EQUAL", OperatorType::kEqual)); ops.emplace_back(new SimpleOperator( - "NOT_EQUAL", OperatorType::kTensorFlowNotEqual)); + "NOT_EQUAL", OperatorType::kNotEqual)); ops.emplace_back(new SimpleOperator("NEG", OperatorType::kNeg)); ops.emplace_back( new SimpleOperator("SELECT", OperatorType::kSelect)); ops.emplace_back( new SimpleOperator("SLICE", OperatorType::kSlice)); + ops.emplace_back(new SimpleOperator("POW", OperatorType::kPow)); // Element-wise operator ops.emplace_back(new SimpleOperator("SIN", OperatorType::kSin)); ops.emplace_back(new SimpleOperator("LOG", OperatorType::kLog)); + ops.emplace_back( + new SimpleOperator("SQRT", OperatorType::kSqrt)); + ops.emplace_back(new SimpleOperator( + "RSQRT", OperatorType::kRsqrt)); return ops; } diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index 03bb20b3208196e964d950c0f0954d1fc0ba9e86..8b6808d3c78d8c51c1b33d09eb4082326100b028 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -112,20 +112,21 @@ TEST_F(OperatorTest, SimpleOperators) { CheckSimpleOperator("LOG_SOFTMAX", OperatorType::kLogSoftmax); CheckSimpleOperator( - "MAXIMUM", OperatorType::kTensorFlowMaximum); + "MAXIMUM", OperatorType::kMaximum); // Element-wise Maximum CheckSimpleOperator( - "MINIMUM", OperatorType::kTensorFlowMinimum); - CheckSimpleOperator("LESS", - OperatorType::kTensorFlowLess); + "MINIMUM", OperatorType::kMinimum); // Element-wise Minimum + CheckSimpleOperator("LESS", OperatorType::kLess); CheckSimpleOperator("NEG", OperatorType::kNeg); CheckSimpleOperator("SELECT", OperatorType::kSelect); CheckSimpleOperator("SLICE", OperatorType::kSlice); CheckSimpleOperator("SIN", OperatorType::kSin); - CheckSimpleOperator("EQUAL", - OperatorType::kTensorFlowEqual); - CheckSimpleOperator( - "NOT_EQUAL", OperatorType::kTensorFlowNotEqual); + CheckSimpleOperator("EQUAL", OperatorType::kEqual); + CheckSimpleOperator("NOT_EQUAL", + OperatorType::kNotEqual); CheckSimpleOperator("LOG", OperatorType::kLog); + CheckSimpleOperator("SQRT", OperatorType::kSqrt); + CheckSimpleOperator("RSQRT", OperatorType::kRsqrt); + CheckSimpleOperator("POW", OperatorType::kPow); } TEST_F(OperatorTest, BuiltinAdd) { @@ -254,7 +255,7 @@ TEST_F(OperatorTest, BuiltinReshape) { TensorFlowReshapeOperator op; op.shape = {1, 2, 4, 5, 8}; auto output_toco_op = SerializeAndDeserialize( - GetOperator("RESHAPE", OperatorType::kTensorFlowReshape), op); + GetOperator("RESHAPE", OperatorType::kReshape), op); EXPECT_EQ(op.shape, output_toco_op->shape); } @@ -277,8 +278,8 @@ TEST_F(OperatorTest, BuiltinSpaceToDepth) { TEST_F(OperatorTest, CustomSplit) { TensorFlowSplitOperator op; op.num_split = 123; - auto output_toco_op = SerializeAndDeserialize( - GetOperator("SPLIT", OperatorType::kTensorFlowSplit), op); + auto output_toco_op = + SerializeAndDeserialize(GetOperator("SPLIT", OperatorType::kSplit), op); EXPECT_EQ(op.num_split, output_toco_op->num_split); } @@ -427,6 +428,14 @@ TEST_F(OperatorTest, BuiltinTransposeConv) { EXPECT_EQ(op.padding.type, output_toco_op->padding.type); } +TEST_F(OperatorTest, BuiltinShape) { + TensorFlowShapeOperator op; + op.output_data_type = ArrayDataType::kInt64; + auto output_toco_op = + SerializeAndDeserialize(GetOperator("SHAPE", OperatorType::kShape), op); + EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type); +} + TEST_F(OperatorTest, BuiltinSparseToDense) { SparseToDenseOperator op; op.validate_indices = false; @@ -446,12 +455,17 @@ TEST_F(OperatorTest, TensorFlowUnsupported) { (*attr)["str_attr"].set_s("Hello World"); (*attr)["int_attr"].set_i(17); (*attr)["bool_attr"].set_b(true); + { + auto* list = (*attr)["list_int_attr"].mutable_list(); + list->add_i(1); + list->add_i(20); + list->add_i(1LL << 40); + list->add_i(-(1LL << 40)); + } node_def.SerializeToString(&op.tensorflow_node_def); - auto output_toco_op = - SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED", - OperatorType::kTensorFlowUnsupported), - op); + auto output_toco_op = SerializeAndDeserialize( + GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op); ::tensorflow::NodeDef output_node_def; output_node_def.ParseFromString(output_toco_op->tensorflow_node_def); @@ -460,15 +474,22 @@ TEST_F(OperatorTest, TensorFlowUnsupported) { EXPECT_EQ("Hello World", output_attr.at("str_attr").s()); EXPECT_EQ(17, output_attr.at("int_attr").i()); EXPECT_EQ(true, output_attr.at("bool_attr").b()); + + { + const auto& list = output_attr.at("list_int_attr").list(); + ASSERT_EQ(4, list.i_size()); + EXPECT_EQ(1, list.i(0)); + EXPECT_EQ(20, list.i(1)); + EXPECT_EQ(1LL << 40, list.i(2)); + EXPECT_EQ(-(1LL << 40), list.i(3)); + } } TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) { TensorFlowUnsupportedOperator op; op.tensorflow_op = "MyCustomUnsupportedOp"; - auto output_toco_op = - SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED", - OperatorType::kTensorFlowUnsupported), - op); + auto output_toco_op = SerializeAndDeserialize( + GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op); ::tensorflow::NodeDef output_node_def; output_node_def.ParseFromString(output_toco_op->tensorflow_node_def); diff --git a/tensorflow/contrib/lite/toco/tflite/types.cc b/tensorflow/contrib/lite/toco/tflite/types.cc index 42c5d7e8ebc3a7b90963a92843af616d9e6532d6..754f0b4b8c661355c99d9e5a86f2d7844414a303 100644 --- a/tensorflow/contrib/lite/toco/tflite/types.cc +++ b/tensorflow/contrib/lite/toco/tflite/types.cc @@ -100,6 +100,8 @@ void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) { return ::tflite::TensorType_STRING; case ArrayDataType::kBool: return ::tflite::TensorType_BOOL; + case ArrayDataType::kComplex64: + return ::tflite::TensorType_COMPLEX64; default: // FLOAT32 is filled for unknown data types. // TODO(ycling): Implement type inference in TF Lite interpreter. @@ -123,6 +125,8 @@ ArrayDataType DataType::Deserialize(int tensor_type) { return ArrayDataType::kUint8; case ::tflite::TensorType_BOOL: return ArrayDataType::kBool; + case ::tflite::TensorType_COMPLEX64: + return ArrayDataType::kComplex64; default: LOG(FATAL) << "Unhandled tensor type '" << tensor_type << "'."; } @@ -147,6 +151,8 @@ flatbuffers::Offset> DataBuffer::Serialize( return CopyBuffer(array, builder); case ArrayDataType::kBool: return CopyBoolToBuffer(array, builder); + case ArrayDataType::kComplex64: + return CopyBuffer(array, builder); default: LOG(FATAL) << "Unhandled array data type."; } @@ -172,6 +178,8 @@ void DataBuffer::Deserialize(const ::tflite::Tensor& tensor, return CopyBuffer(buffer, array); case ::tflite::TensorType_BOOL: return CopyBuffer(buffer, array); + case ::tflite::TensorType_COMPLEX64: + return CopyBuffer(buffer, array); default: LOG(FATAL) << "Unhandled tensor type."; } diff --git a/tensorflow/contrib/lite/toco/tflite/types_test.cc b/tensorflow/contrib/lite/toco/tflite/types_test.cc index 8c6ef95bfab0a5e9b410748eabf9570eec52c2e0..8e9f30ba3a6e6b98fa9c4237567b0797a5a797aa 100644 --- a/tensorflow/contrib/lite/toco/tflite/types_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/types_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/toco/tflite/types.h" +#include + #include #include @@ -71,7 +73,8 @@ TEST(DataType, SupportedTypes) { {ArrayDataType::kInt32, ::tflite::TensorType_INT32}, {ArrayDataType::kInt64, ::tflite::TensorType_INT64}, {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32}, - {ArrayDataType::kBool, ::tflite::TensorType_BOOL}}; + {ArrayDataType::kBool, ::tflite::TensorType_BOOL}, + {ArrayDataType::kComplex64, ::tflite::TensorType_COMPLEX64}}; for (auto x : testdata) { EXPECT_EQ(x.second, DataType::Serialize(x.first)); EXPECT_EQ(x.first, DataType::Deserialize(x.second)); @@ -171,6 +174,14 @@ TEST(DataBuffer, Bool) { ::testing::ElementsAre(true, false, true)); } +TEST(DataBuffer, Complex64) { + Array recovered = ToFlatBufferAndBack( + {std::complex(1.0f, 2.0f), std::complex(3.0f, 4.0f)}); + EXPECT_THAT(recovered.GetBuffer().data, + ::testing::ElementsAre(std::complex(1.0f, 2.0f), + std::complex(3.0f, 4.0f))); +} + TEST(Padding, All) { EXPECT_EQ(::tflite::Padding_SAME, Padding::Serialize(PaddingType::kSame)); EXPECT_EQ(PaddingType::kSame, Padding::Deserialize(::tflite::Padding_SAME)); diff --git a/tensorflow/contrib/lite/toco/toco.cc b/tensorflow/contrib/lite/toco/toco.cc index 8041aa9e7fbfdaf44134395fee4b2bb01633893a..0b460bd178a49cafefd3438b7ae1c38a07b2ab7c 100644 --- a/tensorflow/contrib/lite/toco/toco.cc +++ b/tensorflow/contrib/lite/toco/toco.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" #include "tensorflow/contrib/lite/toco/toco_flags.pb.h" #include "tensorflow/contrib/lite/toco/toco_port.h" -#include "tensorflow/contrib/lite/toco/toco_saved_model.h" #include "tensorflow/contrib/lite/toco/toco_tooling.h" #include "tensorflow/contrib/lite/toco/toco_types.h" #include "tensorflow/core/platform/logging.h" @@ -49,17 +48,6 @@ void CheckFrozenModelPermissions(const Arg& input_file) { << input_file.value() << ".\n"; } -// Checks the permissions of the SavedModel directory. -void CheckSavedModelPermissions(const Arg& savedmodel_directory) { - QCHECK(savedmodel_directory.specified()) - << "Missing required flag --savedmodel_directory.\n"; - QCHECK( - port::file::Exists(savedmodel_directory.value(), port::file::Defaults()) - .ok()) - << "Specified savedmodel_directory does not exist: " - << savedmodel_directory.value() << ".\n"; -} - // Reads the contents of the GraphDef from either the frozen graph file or the // SavedModel directory. If it reads the SavedModel directory, it updates the // ModelFlags and TocoFlags accordingly. @@ -69,24 +57,16 @@ void ReadInputData(const ParsedTocoFlags& parsed_toco_flags, string* graph_def_contents) { port::CheckInitGoogleIsDone("InitGoogle is not done yet.\n"); - bool has_input_file = parsed_toco_flags.input_file.specified(); - bool has_savedmodel_dir = parsed_toco_flags.savedmodel_directory.specified(); - - // Ensure either input_file or savedmodel_directory flag has been set. - QCHECK_NE(has_input_file, has_savedmodel_dir) - << "Specify either input_file or savedmodel_directory flag.\n"; + // Ensure savedmodel_directory is not set. + QCHECK(!parsed_toco_flags.savedmodel_directory.specified()) + << "Use `tensorflow/contrib/lite/python/tflite_convert` script with " + << "SavedModel directories.\n"; // Checks the input file permissions and reads the contents. - if (has_input_file) { - CheckFrozenModelPermissions(parsed_toco_flags.input_file); - CHECK(port::file::GetContents(parsed_toco_flags.input_file.value(), - graph_def_contents, port::file::Defaults()) - .ok()); - } else { - CheckSavedModelPermissions(parsed_toco_flags.savedmodel_directory); - GetSavedModelContents(parsed_toco_flags, parsed_model_flags, toco_flags, - model_flags, graph_def_contents); - } + CheckFrozenModelPermissions(parsed_toco_flags.input_file); + CHECK(port::file::GetContents(parsed_toco_flags.input_file.value(), + graph_def_contents, port::file::Defaults()) + .ok()); } void ToolMain(const ParsedTocoFlags& parsed_toco_flags, diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc index 87a1e429b928bf59cb14597980602953732a7659..c6d0a03452f7477841d7e68665baf32dff45f41c 100644 --- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc @@ -41,7 +41,7 @@ bool ParseTocoFlagsFromCommandLineFlags( "extension."), Flag("savedmodel_directory", parsed_flags.savedmodel_directory.bind(), parsed_flags.savedmodel_directory.default_value(), - "Full path to the directory containing the SavedModel."), + "Deprecated. Full path to the directory containing the SavedModel."), Flag("output_file", parsed_flags.output_file.bind(), parsed_flags.output_file.default_value(), "Output file. " @@ -55,9 +55,9 @@ bool ParseTocoFlagsFromCommandLineFlags( "One of TENSORFLOW_GRAPHDEF, TFLITE, GRAPHVIZ_DOT."), Flag("savedmodel_tagset", parsed_flags.savedmodel_tagset.bind(), parsed_flags.savedmodel_tagset.default_value(), - "Comma-separated set of tags identifying the MetaGraphDef within " - "the SavedModel to analyze. All tags in the tag set must be " - "specified."), + "Deprecated. Comma-separated set of tags identifying the " + "MetaGraphDef within the SavedModel to analyze. All tags in the tag " + "set must be specified."), Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(), parsed_flags.default_ranges_min.default_value(), "If defined, will be used as the default value for the min bound " diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto index ad4e94ded9f9730842a257e065d9aec2b1cbfac8..b4a9870d5834d1d5689d15ebc131ac0ead3e9850 100644 --- a/tensorflow/contrib/lite/toco/toco_flags.proto +++ b/tensorflow/contrib/lite/toco/toco_flags.proto @@ -37,7 +37,7 @@ enum FileFormat { // of as properties of models, instead describing how models are to be // processed in the context of the present tooling job. // -// Next ID to use: 21. +// Next ID to use: 26. message TocoFlags { // Input file format optional FileFormat input_format = 1; diff --git a/tensorflow/contrib/lite/toco/toco_saved_model.cc b/tensorflow/contrib/lite/toco/toco_saved_model.cc deleted file mode 100644 index 26f55a66c729894a990258080e397bb42ea98a13..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/toco/toco_saved_model.cc +++ /dev/null @@ -1,189 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "absl/strings/numbers.h" -#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h" -#include "tensorflow/contrib/lite/toco/toco_saved_model.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/tensor_shape.pb.h" - -namespace toco { -namespace { - -// Loads a SavedModel from the directory specified in parsed_toco_flags. -// Returns a SavedModelBundle with the requested MetaGraphDef. -const tensorflow::SavedModelBundle* LoadSavedModel( - const ParsedTocoFlags& parsed_toco_flags) { - const string model_path = parsed_toco_flags.savedmodel_directory.value(); - QCHECK(tensorflow::MaybeSavedModelDirectory(model_path)) - << "Model is not saved in the supported SavedModel format.\n"; - - // Gets the tags identifying the MetaGraphDef from the command line arguments. - string tags_str; - if (parsed_toco_flags.savedmodel_tagset.specified()) { - tags_str = parsed_toco_flags.savedmodel_tagset.value(); - } else { - tags_str = parsed_toco_flags.savedmodel_tagset.default_value(); - } - auto tags = absl::StrSplit(tags_str, ','); - - // Loads MetaGraphDef. - auto* bundle = new tensorflow::SavedModelBundle; - TF_CHECK_OK(tensorflow::LoadSavedModel(tensorflow::SessionOptions(), - tensorflow::RunOptions(), model_path, - tags, bundle)) - << "Failed to load exported model from " << model_path - << ". Ensure the model contains the required tags '" << tags_str - << "'.\n"; - return bundle; -} - -// Returns the array name without the postfix. -// -// e.g. reduces "input:0" to "input". -string GetArrayName(const string& name) { - const std::vector& names = absl::StrSplit(name, ':'); - return names[0]; -} - -// Returns the list of array names without the postfix sorted alphabetically. -std::set GetSortedNames(const std::unordered_set& names) { - std::vector final_names; - final_names.reserve(names.size()); - for (const auto& name : names) { - final_names.push_back(GetArrayName(name)); - } - return std::set(final_names.begin(), final_names.end()); -} - -// Gets the final shape after replacing the first dimension with batch size, if -// it is undefined (containing the value -1). Returns whether the shape is -// valid. -bool ReplaceShapeBatchSize(const tensorflow::TensorShapeProto& shape, - int batch_size, - tensorflow::TensorShapeProto* final_shape) { - for (int idx = 0; idx < shape.dim().size(); ++idx) { - int64 final_dim = shape.dim()[idx].size(); - if (final_dim == -1) { - if (idx > 0) return false; - final_dim = batch_size; - } - final_shape->add_dim()->set_size(final_dim); - } - return true; -} - -// Updates the input arrays in ModelFlags to contain the shape of the array. -void ProcessInputShapes(const tensorflow::GraphDef& graph_def, int batch_size, - ModelFlags* model_flags) { - // Build map of input array names to input arrays. - std::unordered_map input_data_map; - for (auto& input : *model_flags->mutable_input_arrays()) { - input_data_map[input.name()] = &input; - } - - // Adds shapes to the input arrays if the shape is valid. - for (const tensorflow::NodeDef& node_def : graph_def.node()) { - if (input_data_map.find(node_def.name()) != input_data_map.end()) { - const auto shape_it = node_def.attr().find("shape"); - if (shape_it != node_def.attr().end()) { - tensorflow::TensorShapeProto final_shape; - bool is_valid = ReplaceShapeBatchSize(shape_it->second.shape(), - batch_size, &final_shape); - - if (is_valid) { - auto* shape = input_data_map.at(node_def.name())->mutable_shape(); - QCHECK_EQ(shape->dims_size(), 0) - << "The shape for the input '" << node_def.name() - << "' was previously defined. For clarity please define inputs " - << "via --input_arrays and input_shapes flags.\n"; - for (const auto& dim : final_shape.dim()) { - shape->add_dims(dim.size()); - } - } - } - } - } - - // Checks all input arrays have a shape. - for (auto const& input : model_flags->input_arrays()) { - QCHECK(input.shape().dims_size() > 0) - << "A valid input shape was not found for input '" << input.name() - << "'. Please define via --input_arrays and --input_shapes flags.\n"; - } -} - -} // namespace - -void ParseMetaData(const tensorflow::GraphDef& graph_def, - const std::unordered_set& inputs, - const std::unordered_set& outputs, - const ParsedTocoFlags& parsed_toco_flags, - const ParsedModelFlags& parsed_model_flags, - TocoFlags* toco_flags, ModelFlags* model_flags) { - if (!parsed_model_flags.input_arrays.specified()) { - const std::set sorted_inputs = GetSortedNames(inputs); - for (const auto& input_name : sorted_inputs) { - model_flags->add_input_arrays()->set_name(input_name); - } - } - - if (!parsed_model_flags.output_arrays.specified()) { - const std::set sorted_outputs = GetSortedNames(outputs); - for (const auto& output_name : sorted_outputs) { - model_flags->add_output_arrays(GetArrayName(output_name)); - } - } - - if (!parsed_model_flags.input_shapes.specified()) { - int batch_size = parsed_model_flags.batch_size.value(); - ProcessInputShapes(graph_def, batch_size, model_flags); - } - - if (!parsed_toco_flags.inference_type.specified()) { - toco_flags->set_inference_type(IODataType::FLOAT); - } -} - -// TODO(nupurgarg): Add top level tests. -void GetSavedModelContents(const ParsedTocoFlags& parsed_toco_flags, - const ParsedModelFlags& parsed_model_flags, - TocoFlags* toco_flags, ModelFlags* model_flags, - string* graph_def_contents) { - // Loads the MetaGraphDef within a SavedModelBundle. - auto bundle = LoadSavedModel(parsed_toco_flags); - - // Converts the MetaGraphDef to frozen GraphDef. - tensorflow::GraphDef frozen_graph_def; - std::unordered_set inputs; - std::unordered_set outputs; - TF_CHECK_OK(tensorflow::FreezeSavedModel(*bundle, &frozen_graph_def, &inputs, - &outputs)); - - // Reads the frozen GraphDef into a string. - QCHECK(frozen_graph_def.SerializeToString(graph_def_contents)) - << "Unable to generate serialized GraphDef.\n"; - - // Process inputs and outputs and metadata within GraphDef. - const tensorflow::GraphDef graph_def = bundle->meta_graph_def.graph_def(); - ParseMetaData(graph_def, inputs, outputs, parsed_toco_flags, - parsed_model_flags, toco_flags, model_flags); -} - -} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco_saved_model.h b/tensorflow/contrib/lite/toco/toco_saved_model.h deleted file mode 100644 index 7a0fabd82d90131a3b2d28c757c08dcb0f9e3988..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/toco/toco_saved_model.h +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_ - -#include -#include - -#include "tensorflow/cc/tools/freeze_saved_model.h" -#include "tensorflow/contrib/lite/toco/args.h" -#include "tensorflow/contrib/lite/toco/model_flags.pb.h" -#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" -#include "tensorflow/contrib/lite/toco/types.pb.h" - -namespace toco { - -// Parses metadata into `toco_flags` and `model_flags`. -// -// Stores `inputs` as input_arrays and `outputs` as output_arrays in -// `model_flags`. Infers input_shapes from the GraphDef and stores it in -// `model_flags` as part of the input_arrays. Assumes inference_type is FLOAT -// and stores it in `toco_flags`. -void ParseMetaData(const tensorflow::GraphDef& graph_def, - const std::unordered_set& inputs, - const std::unordered_set& outputs, - const ParsedTocoFlags& parsed_toco_flags, - const ParsedModelFlags& parsed_model_flags, - TocoFlags* toco_flags, ModelFlags* model_flags); - -// Generates a frozen graph from the SavedModel in the directory specified in -// `toco_flags`. Reads frozen graph contents into `graph_def_contents`. Parses -// metadata relating to the GraphDef into `toco_flags` and `model_flags`. -void GetSavedModelContents(const ParsedTocoFlags& parsed_toco_flags, - const ParsedModelFlags& parsed_model_flags, - TocoFlags* toco_flags, ModelFlags* model_flags, - string* graph_def_contents); - -} // namespace toco - -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_ diff --git a/tensorflow/contrib/lite/toco/toco_saved_model_test.cc b/tensorflow/contrib/lite/toco/toco_saved_model_test.cc deleted file mode 100644 index 5e122afe65dc29abc85f142f4019aae5058ace51..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/toco/toco_saved_model_test.cc +++ /dev/null @@ -1,274 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/contrib/lite/toco/toco_saved_model.h" -#include "absl/strings/str_join.h" -#include "tensorflow/cc/framework/scope.h" -#include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h" -#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" -#include "tensorflow/core/lib/core/status_test_util.h" - -#include -#include - -namespace toco { -namespace { - -using tensorflow::ops::Add; -using tensorflow::ops::Const; -using tensorflow::ops::FakeQuantWithMinMaxArgs; -using tensorflow::ops::Placeholder; - -class TocoSavedModelTest : public ::testing::Test { - protected: - // Calls functions to process cmdline arguments and calls ParseMetaData. - // ParseMetaData parses input_arrays, output_arrays, and gets metadata from - // SavedModel it is not defined in the cmdline arguments. - void ProcessGraphDefMetadata(const std::unordered_set& inputs, - const std::unordered_set& outputs, - const tensorflow::GraphDef& graph_def) { - ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags_, &toco_flags_); - ReadModelFlagsFromCommandLineFlags(parsed_model_flags_, &model_flags_); - ParseMetaData(graph_def, inputs, outputs, parsed_toco_flags_, - parsed_model_flags_, &toco_flags_, &model_flags_); - } - - // Gets the GraphDef from the SavedModelBundle and processes metadata. - void ProcessSavedModelMetadata(const std::unordered_set& inputs, - const std::unordered_set& outputs) { - const tensorflow::GraphDef graph_def = bundle_.meta_graph_def.graph_def(); - ProcessGraphDefMetadata(inputs, outputs, graph_def); - } - - // Returns a GraphDef representing a simple float model with a single input. - tensorflow::GraphDef GetFloatGraphDef(const std::vector& shape) { - tensorflow::GraphDef graph_def; - tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); - - tensorflow::Output input = - Placeholder(scope.WithOpName("input"), tensorflow::DT_FLOAT, - Placeholder::Shape(tensorflow::PartialTensorShape(shape))); - tensorflow::Output zero = Const(scope.WithOpName("zero"), 0.0f, {}); - tensorflow::Output add = Add(scope.WithOpName("add"), input, zero); - - TF_EXPECT_OK(scope.ToGraphDef(&graph_def)); - return graph_def; - } - - // Returns a GraphDef representing a simple float model with two inputs. - tensorflow::GraphDef GetComplexFloatGraphDef() { - tensorflow::GraphDef graph_def; - tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); - - tensorflow::Output inputA = - Placeholder(scope.WithOpName("inputA"), tensorflow::DT_FLOAT, - Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1}))); - tensorflow::Output inputB = - Placeholder(scope.WithOpName("inputB"), tensorflow::DT_FLOAT, - Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1}))); - tensorflow::Output add = Add(scope.WithOpName("add"), inputB, inputA); - - TF_EXPECT_OK(scope.ToGraphDef(&graph_def)); - return graph_def; - } - - // Returns a GraphDef representing a simple quantized model. - tensorflow::GraphDef GetQuantizedGraphDef() { - tensorflow::GraphDef graph_def; - tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); - - tensorflow::Output input = - Placeholder(scope.WithOpName("input"), tensorflow::DT_FLOAT, - Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1}))); - tensorflow::Output zero = Const(scope.WithOpName("zero"), 0.0f, {}); - tensorflow::Output fake_quant = - FakeQuantWithMinMaxArgs(scope.WithOpName("quant"), zero); - tensorflow::Output add = Add(scope.WithOpName("add"), input, fake_quant); - - TF_EXPECT_OK(scope.ToGraphDef(&graph_def)); - return graph_def; - } - - // Gets the values in the input_arrays flag. - std::vector GetInputArrays() { - std::vector actual; - for (const auto& input : model_flags_.input_arrays()) { - actual.push_back(input.name()); - } - return actual; - } - - // Gets the values in the output_arrays flag. - std::vector GetOutputArrays() { - std::vector actual(model_flags_.output_arrays().begin(), - model_flags_.output_arrays().end()); - return actual; - } - - // Gets the shape of the given input array. - string GetInputShape(const string& input_array) { - for (const auto& input : model_flags_.input_arrays()) { - if (input.name() == input_array) { - std::vector dims; - for (int idx = 0; idx < input.shape().dims_size(); ++idx) { - dims.push_back(std::to_string(input.shape().dims(idx))); - } - return absl::StrJoin(dims, ","); - } - } - return ""; - } - - tensorflow::SavedModelBundle bundle_; - ParsedTocoFlags parsed_toco_flags_; - ParsedModelFlags parsed_model_flags_; - TocoFlags toco_flags_; - ModelFlags model_flags_; -}; - -// Tests if input_arrays, output_arrays, inference_type, and output_arrays are -// added to ModelFlags if they are not specified in cmdline arguments. -// Tests if the default batch size replaces a -1 in the first dimension. -TEST_F(TocoSavedModelTest, NoCmdLine) { - tensorflow::GraphDef graph_def = GetFloatGraphDef({-1, 3, 3, 1}); - - ProcessGraphDefMetadata({"input"}, {"add"}, graph_def); - EXPECT_EQ(GetInputArrays(), std::vector({"input"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"add"})); - EXPECT_EQ(GetInputShape("input"), "1,3,3,1"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); -} - -// Tests if the order of input_arrays and output_arrays is deterministic when -// they are taken from the SavedModel. -TEST_F(TocoSavedModelTest, NoCmdLineMultipleArrays) { - tensorflow::GraphDef graph_def = GetComplexFloatGraphDef(); - - // Note: The model does not have two outputs. However, the function does not - // need an accurate output_array list. This is only meant to test order. - ProcessGraphDefMetadata({"inputB", "inputA"}, {"add", "invalid"}, graph_def); - EXPECT_EQ(GetInputArrays(), std::vector({"inputA", "inputB"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"add", "invalid"})); - EXPECT_EQ(GetInputShape("inputA"), "1,3,3,1"); - EXPECT_EQ(GetInputShape("inputB"), "1,3,3,1"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); -} - -// Tests if input_shapes is inferred when input_arrays is passed in via cmdline -// arguments. -TEST_F(TocoSavedModelTest, InputNameWithoutInputShape) { - parsed_model_flags_.input_arrays.bind()("input"); - tensorflow::GraphDef graph_def = GetFloatGraphDef({2, 3, 3, 1}); - - ProcessGraphDefMetadata({"not_used_input"}, {"add"}, graph_def); - EXPECT_EQ(GetInputArrays(), std::vector({"input"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"add"})); - EXPECT_EQ(GetInputShape("input"), "2,3,3,1"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); -} - -// Ensures a failure occurs when input_shapes is defined without input_arrays. -TEST_F(TocoSavedModelTest, InputShapeWithoutInputName) { - parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12"); - tensorflow::GraphDef graph_def = GetFloatGraphDef({1, 3, 3, 1}); - - EXPECT_DEATH(ProcessGraphDefMetadata({"input"}, {"add"}, graph_def), - "failed: input_shapes.size\\(\\) == " - "model_flags->input_arrays_size\\(\\)"); -} - -// Tests if the cmdline values of input_arrays, input_shapes are used when -// specified with an empty GraphDef. -TEST_F(TocoSavedModelTest, InputArraysCmdLine) { - parsed_model_flags_.input_arrays.bind()("inputA,inputB"); - parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12"); - - ProcessSavedModelMetadata({"input0", "input1"}, {"output0", "output1"}); - EXPECT_EQ(GetInputArrays(), std::vector({"inputA", "inputB"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"output0", "output1"})); - EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1"); - EXPECT_EQ(GetInputShape("inputB"), "9,12"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); -} - -// Tests if the cmdline values of input_arrays, input_shapes are used when -// specified even if values exist within the GraphDef. -TEST_F(TocoSavedModelTest, InputArraysCmdLineWithGraphDef) { - parsed_model_flags_.input_arrays.bind()("inputA"); - parsed_model_flags_.input_shapes.bind()("1,224,224,1"); - tensorflow::GraphDef graph_def = GetFloatGraphDef({1, 3, 3, 1}); - - ProcessGraphDefMetadata({"inputA"}, {"add"}, graph_def); - EXPECT_EQ(GetInputArrays(), std::vector({"inputA"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"add"})); - EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); -} - -// Tests if the cmdline values of input_arrays, input_shapes, inference_type, -// and output_arrays are used when specified with an empty GraphDef. -TEST_F(TocoSavedModelTest, AllParamsCmdLine) { - parsed_model_flags_.input_arrays.bind()("inputA,inputB"); - parsed_model_flags_.output_arrays.bind()("outputA,outputB"); - parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12"); - parsed_toco_flags_.inference_type.bind()("FLOAT"); - - ProcessSavedModelMetadata({"input0", "input1"}, {"output0", "output1"}); - EXPECT_EQ(GetInputArrays(), std::vector({"inputA", "inputB"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"outputA", "outputB"})); - EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1"); - EXPECT_EQ(GetInputShape("inputB"), "9,12"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); -} - -// Tests if a quantized graph gives the correct values assuming type is passed -// in via command line. -TEST_F(TocoSavedModelTest, QuantizedNoCmdLine) { - parsed_toco_flags_.inference_type.bind()("QUANTIZED_UINT8"); - tensorflow::GraphDef graph_def = GetQuantizedGraphDef(); - - ProcessGraphDefMetadata({"input"}, {"add"}, graph_def); - EXPECT_EQ(GetInputArrays(), std::vector({"input"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"add"})); - EXPECT_EQ(GetInputShape("input"), "1,3,3,1"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::QUANTIZED_UINT8); -} - -// Tests if the provided batch size replaces a -1 in the first dimension of -// input shape. -TEST_F(TocoSavedModelTest, MissingShapeParameterValid) { - parsed_model_flags_.batch_size.bind()(3); - tensorflow::GraphDef graph_def = GetFloatGraphDef({-1, 3, 3, 1}); - - ProcessGraphDefMetadata({"input"}, {"add"}, graph_def); - EXPECT_EQ(GetInputArrays(), std::vector({"input"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"add"})); - EXPECT_EQ(GetInputShape("input"), "3,3,3,1"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); -} - -// Ensures a failure occurs if there is a -1 in a dimension aside from the first -// position of input shape. -TEST_F(TocoSavedModelTest, MissingShapeParameterInvalid) { - parsed_model_flags_.batch_size.bind()(3); - tensorflow::GraphDef graph_def = GetFloatGraphDef({1, -1, 3, 1}); - - EXPECT_DEATH(ProcessGraphDefMetadata({"input"}, {"add"}, graph_def), - "A valid input shape was not found for input 'input'."); -} - -} // namespace -} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 3173d524b7fd043aeec72322875a39d2268ca3f6..a057dcef121a9a17b15d0b19ca908d12d89b0367 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -34,11 +34,11 @@ limitations under the License. namespace toco { namespace { -// CHECK-fails if the model contains a kTensorFlowUnsupported operation. +// CHECK-fails if the model contains a kUnsupported operation. void CheckUnsupportedOperations(const Model& model) { std::set unsupported_ops; for (auto& op : model.operators) { - if (op->type == OperatorType::kTensorFlowUnsupported) { + if (op->type == OperatorType::kUnsupported) { unsupported_ops.insert( static_cast(op.get()) ->tensorflow_op); @@ -134,6 +134,8 @@ bool SupportsPreallocatedWorkspace(FileFormat format) { return (format == TFLITE); } +bool SupportsShuffledFCWeights(FileFormat format) { return format == TFLITE; } + bool IsRealValued(toco::ArrayDataType type) { // TODO(benoitjacob) - this is hardcoding that uint8 and int16 are only used // for quantized real-number values, and no other integer type is ever used @@ -335,6 +337,10 @@ void Transform(const TocoFlags& toco_flags, Model* model) { new RemoveFinalDequantizeOp, ensure_safe_for_int8_kernels, }); + if (SupportsShuffledFCWeights(output_format)) { + RunGraphTransformations(model, "shuffling of FC weights", + {new ShuffleFCWeights}); + } } else { GraphTransformationsSet dequantization_transformations{new Dequantize}; // Dequantize creates FakeQuant nodes. We may want to discard diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 92bab5246cb85052b5e0216f1cb8a04736ae7a79..7dc1af9f1dc13fdc0f166e12bfd616cfdacc06c9 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -338,23 +338,23 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Div) HANDLE_OPERATORTYPENAME_CASE(Tanh) HANDLE_OPERATORTYPENAME_CASE(Sin) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowAll) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowAssert) + HANDLE_OPERATORTYPENAME_CASE(All) + HANDLE_OPERATORTYPENAME_CASE(Assert) HANDLE_OPERATORTYPENAME_CASE(ExpandDims) HANDLE_OPERATORTYPENAME_CASE(Fill) HANDLE_OPERATORTYPENAME_CASE(FloorMod) HANDLE_OPERATORTYPENAME_CASE(FloorDiv) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreater) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreaterEqual) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowIdentity) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowLess) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowLessEqual) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMatMul) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMax) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMaximum) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMerge) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMin) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMinimum) + HANDLE_OPERATORTYPENAME_CASE(Greater) + HANDLE_OPERATORTYPENAME_CASE(GreaterEqual) + HANDLE_OPERATORTYPENAME_CASE(Identity) + HANDLE_OPERATORTYPENAME_CASE(Less) + HANDLE_OPERATORTYPENAME_CASE(LessEqual) + HANDLE_OPERATORTYPENAME_CASE(MatMul) + HANDLE_OPERATORTYPENAME_CASE(Max) // Reduction Max + HANDLE_OPERATORTYPENAME_CASE(Maximum) // Element-wise Maximum + HANDLE_OPERATORTYPENAME_CASE(Merge) + HANDLE_OPERATORTYPENAME_CASE(Min) // Reduction Min + HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum HANDLE_OPERATORTYPENAME_CASE(Neg) HANDLE_OPERATORTYPENAME_CASE(Pad) HANDLE_OPERATORTYPENAME_CASE(PadV2) @@ -362,22 +362,22 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Stack) HANDLE_OPERATORTYPENAME_CASE(Range) HANDLE_OPERATORTYPENAME_CASE(Rank) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowReshape) + HANDLE_OPERATORTYPENAME_CASE(Reshape) HANDLE_OPERATORTYPENAME_CASE(Squeeze) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowRsqrt) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowShape) + HANDLE_OPERATORTYPENAME_CASE(Rsqrt) + HANDLE_OPERATORTYPENAME_CASE(Shape) HANDLE_OPERATORTYPENAME_CASE(Slice) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowSplit) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowSqrt) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowSquare) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowSwitch) + HANDLE_OPERATORTYPENAME_CASE(Split) + HANDLE_OPERATORTYPENAME_CASE(Sqrt) + HANDLE_OPERATORTYPENAME_CASE(Square) + HANDLE_OPERATORTYPENAME_CASE(Switch) HANDLE_OPERATORTYPENAME_CASE(Sub) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowSum) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowTile) + HANDLE_OPERATORTYPENAME_CASE(Sum) + HANDLE_OPERATORTYPENAME_CASE(Tile) HANDLE_OPERATORTYPENAME_CASE(Transpose) HANDLE_OPERATORTYPENAME_CASE(TransposeConv) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcat) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcatV2) + HANDLE_OPERATORTYPENAME_CASE(Concat) + HANDLE_OPERATORTYPENAME_CASE(ConcatV2) HANDLE_OPERATORTYPENAME_CASE(Cast) HANDLE_OPERATORTYPENAME_CASE(Floor) HANDLE_OPERATORTYPENAME_CASE(Gather) @@ -388,14 +388,15 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Svdf) HANDLE_OPERATORTYPENAME_CASE(ArgMax) HANDLE_OPERATORTYPENAME_CASE(TopK_V2) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowUnsupported) + HANDLE_OPERATORTYPENAME_CASE(Unsupported) HANDLE_OPERATORTYPENAME_CASE(Exp) HANDLE_OPERATORTYPENAME_CASE(DynamicPartition) HANDLE_OPERATORTYPENAME_CASE(DynamicStitch) HANDLE_OPERATORTYPENAME_CASE(Select) HANDLE_OPERATORTYPENAME_CASE(SparseToDense) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowEqual) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowNotEqual) + HANDLE_OPERATORTYPENAME_CASE(Equal) + HANDLE_OPERATORTYPENAME_CASE(NotEqual) + HANDLE_OPERATORTYPENAME_CASE(Pow) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE @@ -403,7 +404,7 @@ const char* OperatorTypeName(OperatorType type) { } string HelpfulOperatorTypeName(const Operator& op) { - if (op.type == OperatorType::kTensorFlowUnsupported) { + if (op.type == OperatorType::kUnsupported) { return toco::port::StringF( "(Unsupported TensorFlow op: %s)", static_cast(op).tensorflow_op); @@ -413,16 +414,20 @@ string HelpfulOperatorTypeName(const Operator& op) { bool OperatorSupportsFusedActivation(OperatorType type) { switch (type) { - case OperatorType::kConcatenation: - case OperatorType::kFakeQuant: - case OperatorType::kGather: - case OperatorType::kSlice: - case OperatorType::kSqueeze: - case OperatorType::kTensorFlowReshape: - case OperatorType::kTensorFlowSplit: - return false; - default: + case OperatorType::kAdd: + case OperatorType::kAveragePool: + case OperatorType::kBatchNormalization: + case OperatorType::kConv: + case OperatorType::kDepthwiseConv: + case OperatorType::kDiv: + case OperatorType::kFullyConnected: + case OperatorType::kL2Pool: + case OperatorType::kMaxPool: + case OperatorType::kMul: + case OperatorType::kSub: return true; + default: + return false; } } @@ -2196,4 +2201,51 @@ void UseArraysExtraInfo(Model* model, bool quantize_output) { } } +void UndoWeightsShuffling(Model* model) { + for (const auto& op : model->operators) { + if (op->type != toco::OperatorType::kFullyConnected) { + continue; + } + const auto& fc_op = static_cast(*op); + if (fc_op.weights_format == FullyConnectedWeightsFormat::kDefault) { + continue; + } + const string& weights_name = fc_op.inputs[1]; + QCHECK_EQ(CountOpsWithInput(*model, weights_name), 1); + auto& weights_array = model->GetArray(weights_name); + QCHECK(weights_array.data_type == ArrayDataType::kUint8); + auto& weights_data = + weights_array.GetMutableBuffer().data; + const auto& weights_shape = weights_array.shape(); + QCHECK_EQ(weights_shape.dimensions_count(), 2); + const int rows = weights_shape.dims(0); + const int cols = weights_shape.dims(1); + QCHECK_EQ(rows % 4, 0); + QCHECK_EQ(cols % 16, 0); + CHECK_EQ(rows * cols, weights_data.size()); + // Compute the de-shuffled weights + std::vector deshuffled_data(weights_data.size()); + uint8* shuffled_data_ptr = weights_data.data(); + for (int r = 0; r < rows; r += 4) { + for (int c = 0; c < cols; c += 16) { + for (int i = 0; i < 4; i++) { + uint8* deshuffled_data_ptr = + deshuffled_data.data() + (r + i) * cols + c; + for (int j = 0; j < 16; j++) { + uint8 shuffled_val = *shuffled_data_ptr++; + // Deshuffling isn't only about deshuffling the storage layout, + // it's also about undoing the flipping of the sign bit, which is + // performed on the shuffled weights. + uint8 deshuffled_val = shuffled_val ^ 0x80; + *deshuffled_data_ptr++ = deshuffled_val; + } + } + } + } + CHECK_EQ(shuffled_data_ptr, weights_data.data() + rows * cols); + // Switch this FC op to using the deshuffled weights. + weights_data = std::move(deshuffled_data); + } +} + } // namespace toco diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index 7681ce9d39ec56f9447896682b52bd4efb1d0e54..5dbfa54fa0369676dce638aec171b409a468da9f 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -101,6 +101,8 @@ std::vector>::iterator FindOp(Model& model, const char* OperatorTypeName(OperatorType type); string HelpfulOperatorTypeName(const Operator& op); +// Whether the operator can be fused with an activation function. Note that this +// will return false by default for new operators; fusing support is opt-in. bool OperatorSupportsFusedActivation(OperatorType type); void DumpGraphvizVideoFrame(const Model& model); @@ -342,6 +344,11 @@ tensorflow::Status NumElements(const std::vector& shape, U* num_elements) { return tensorflow::Status::OK(); } +// A model file may have shuffled FC weights. +// When that happens, we want to de-shuffle them immediately on import, +// so that the rest of toco doesn't need to know about shuffled weights. +void UndoWeightsShuffling(Model* model); + } // namespace toco #endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ diff --git a/tensorflow/contrib/lite/toco/tooling_util_test.cc b/tensorflow/contrib/lite/toco/tooling_util_test.cc index a683867374c8b8dcb274478adf6b5fa0691d1c5a..8609e5beddd200be4e5ebfe1fb2a79048e0e60ab 100644 --- a/tensorflow/contrib/lite/toco/tooling_util_test.cc +++ b/tensorflow/contrib/lite/toco/tooling_util_test.cc @@ -175,4 +175,10 @@ TEST(NumElementsTest, UnsignedInt64) { EXPECT_EQ(status.error_message(), kLargeTensorMessage); } +TEST(FusedActivationTest, DefaultsToUnfused) { + EXPECT_TRUE(OperatorSupportsFusedActivation(OperatorType::kAdd)); + EXPECT_FALSE(OperatorSupportsFusedActivation(OperatorType::kNone)); + EXPECT_FALSE(OperatorSupportsFusedActivation(static_cast(255))); +} + } // namespace toco diff --git a/tensorflow/contrib/lite/tools/benchmark/BUILD b/tensorflow/contrib/lite/tools/benchmark/BUILD index 8857062c000201e1077469fc36e3bf2760924a30..183a545295f690decec47f1c31aa473667408a3d 100644 --- a/tensorflow/contrib/lite/tools/benchmark/BUILD +++ b/tensorflow/contrib/lite/tools/benchmark/BUILD @@ -66,6 +66,16 @@ cc_library( ], ) +cc_library( + name = "benchmark_params", + srcs = [ + "benchmark_params.cc", + "logging.h", + ], + hdrs = ["benchmark_params.h"], + copts = common_copts, +) + cc_library( name = "benchmark_model_lib", srcs = [ @@ -75,6 +85,7 @@ cc_library( hdrs = ["benchmark_model.h"], copts = common_copts, deps = [ + ":benchmark_params", ":command_line_flags", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:string_util", diff --git a/tensorflow/contrib/lite/tools/benchmark/README.md b/tensorflow/contrib/lite/tools/benchmark/README.md index c10826afff6d5569545d4b7df73c88d24d9dcd1a..93769305bde210b58f3b2cb668a9d8c1ad0ce396 100644 --- a/tensorflow/contrib/lite/tools/benchmark/README.md +++ b/tensorflow/contrib/lite/tools/benchmark/README.md @@ -3,7 +3,38 @@ ## Description A simple C++ binary to benchmark a TFLite model and its individual operators, -both on desktop machines and on Android. +both on desktop machines and on Android. The binary takes a TFLite model, +generates random inputs and then repeatedly runs the model for specified number +of runs. Aggregrate latency statistics are reported after running the benchmark. + +The instructions below are for running the binary on Desktop and Android, +for iOS please use the +[iOS benchmark app] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios). + +## Parameters + +The binary takes the following required parameters: + +* `graph`: `string` \ + The path to the TFLite model file. +* `input_layer`: `string` \ + The name of the input layer, this is typically the first layer of the model. +* `input_layer_shape`: `string` \ + The shape of the input layer. This is a comma separated string of the shape + of tensor of input layer. + +and the following optional parameters: + +* `num_threads`: `int` (default=1) \ + The number of threads to use for running TFLite interpreter. +* `warmup_runs`: `int` (default=1) \ + The number of warmup runs to do before starting the benchmark. +* `run_delay`: `float` (default=-1.0) \ + The delay in seconds between subsequent benchmark runs. Non-positive values + mean use no delay. +* `use_nnapi`: `bool` (default=false) \ + Whether to use [Android NNAPI] (https://developer.android.com/ndk/guides/neuralnetworks/). + This API is available on recent Android devices. ## To build/install/run @@ -44,7 +75,7 @@ adb push mobilenet_quant_v1_224.tflite /data/local/tmp ``` adb shell /data/local/tmp/benchmark_model \ --graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \ - --input_layer="Placeholder" \ + --input_layer="input" \ --input_layer_shape="1,224,224,3" \ --num_threads=4 ``` @@ -70,6 +101,30 @@ bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model \ The MobileNet graph used as an example here may be downloaded from https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip + +## Reducing variance between runs on Android. + +Most modern Android phones use [ARM big.LITTLE](https://en.wikipedia.org/wiki/ARM_big.LITTLE) +architecture where some cores are more power hungry but faster than other cores. +When running benchmarks on these phones there can be significant variance +between different runs of the benchmark. One way to reduce variance between runs +is to set the [CPU affinity](https://en.wikipedia.org/wiki/Processor_affinity) +before running the benchmark. On Android this can be done using the `taskset` +command. +E.g. for running the benchmark on big cores on Pixel 2 with a single thread one +can use the following command: + +``` +adb shell tasket f0 /data/local/tmp/benchmark_model \ + --graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \ + --input_layer="input" \ + --input_layer_shape="1,224,224,3" \ + --num_threads=1 +``` + +where `f0` is the affinity mask for big cores on Pixel 2. +Note: The affinity mask varies with the device. + ## Profiling model operators The benchmark model binary also allows you to profile operators and give execution times of each operator. To do this, compile the binary with a compiler flag that enables profiling to be compiled in. Pass **--copt=-DTFLITE_PROFILING_ENABLED** diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc index a8a9a6112c1ec050be8d0bcfe9dc5f00df40d3ff..08648bcfe26365d180d984fde8f8e04b22eb45dd 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc @@ -48,6 +48,19 @@ namespace tflite { namespace benchmark { using tensorflow::Stat; +BenchmarkParams BenchmarkModel::DefaultParams() { + BenchmarkParams params; + params.AddParam("num_runs", BenchmarkParam::Create(50)); + params.AddParam("run_delay", BenchmarkParam::Create(-1.0f)); + params.AddParam("num_threads", BenchmarkParam::Create(1)); + params.AddParam("benchmark_name", BenchmarkParam::Create("")); + params.AddParam("output_prefix", BenchmarkParam::Create("")); + params.AddParam("warmup_runs", BenchmarkParam::Create(1)); + return params; +} + +BenchmarkModel::BenchmarkModel() : params_(DefaultParams()) {} + void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults &results) { auto inference_us = results.inference_time_us(); auto init_us = results.startup_latency_us(); @@ -60,24 +73,29 @@ void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults &results) { std::vector BenchmarkModel::GetFlags() { return { - Flag("num_runs", ¶ms_.num_runs, "number of runs"), - Flag("run_delay", ¶ms_.run_delay, "delay between runs in seconds"), - Flag("num_threads", ¶ms_.num_threads, "number of threads"), - Flag("benchmark_name", ¶ms_.benchmark_name, "benchmark name"), - Flag("output_prefix", ¶ms_.output_prefix, "benchmark output prefix"), - Flag("warmup_runs", ¶ms_.warmup_runs, - "how many runs to initialize model"), + CreateFlag("num_runs", ¶ms_, "number of runs"), + CreateFlag("run_delay", ¶ms_, "delay between runs in seconds"), + CreateFlag("num_threads", ¶ms_, "number of threads"), + CreateFlag("benchmark_name", ¶ms_, "benchmark name"), + CreateFlag("output_prefix", ¶ms_, + "benchmark output prefix"), + CreateFlag("warmup_runs", ¶ms_, + "how many runs to initialize model"), }; } void BenchmarkModel::LogFlags() { - TFLITE_LOG(INFO) << "Num runs: [" << params_.num_runs << "]"; - TFLITE_LOG(INFO) << "Inter-run delay (seconds): [" << params_.run_delay + TFLITE_LOG(INFO) << "Num runs: [" << params_.Get("num_runs") << "]"; + TFLITE_LOG(INFO) << "Inter-run delay (seconds): [" + << params_.Get("run_delay") << "]"; + TFLITE_LOG(INFO) << "Num threads: [" << params_.Get("num_threads") + << "]"; + TFLITE_LOG(INFO) << "Benchmark name: [" + << params_.Get("benchmark_name") << "]"; + TFLITE_LOG(INFO) << "Output prefix: [" + << params_.Get("output_prefix") << "]"; + TFLITE_LOG(INFO) << "Warmup runs: [" << params_.Get("warmup_runs") << "]"; - TFLITE_LOG(INFO) << "Num threads: [" << params_.num_threads << "]"; - TFLITE_LOG(INFO) << "Benchmark name: [" << params_.benchmark_name << "]"; - TFLITE_LOG(INFO) << "Output prefix: [" << params_.output_prefix << "]"; - TFLITE_LOG(INFO) << "Warmup runs: [" << params_.warmup_runs << "]"; } Stat BenchmarkModel::Run(int num_times, RunType run_type) { @@ -91,7 +109,7 @@ Stat BenchmarkModel::Run(int num_times, RunType run_type) { listeners_.OnSingleRunEnd(); run_stats.UpdateStat(end_us - start_us); - SleepForSeconds(params_.run_delay); + SleepForSeconds(params_.Get("run_delay")); } std::stringstream stream; @@ -117,8 +135,10 @@ void BenchmarkModel::Run(int argc, char **argv) { << "ms"; uint64_t input_bytes = ComputeInputBytes(); - Stat warmup_time_us = Run(params_.warmup_runs, WARMUP); - Stat inference_time_us = Run(params_.num_runs, REGULAR); + Stat warmup_time_us = + Run(params_.Get("warmup_runs"), WARMUP); + Stat inference_time_us = + Run(params_.Get("num_runs"), REGULAR); listeners_.OnBenchmarkEnd( {startup_latency_us, input_bytes, warmup_time_us, inference_time_us}); } diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h index d48f693693c2cee0cd2e2a6f2b4c590998feffb3..942e21f67a7f864f16b7b1b85b2599d5c872b5c7 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_params.h" #include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h" #include "tensorflow/core/util/stats_calculator.h" @@ -63,17 +64,6 @@ class BenchmarkResults { tensorflow::Stat inference_time_us_; }; -struct BenchmarkParams { - BenchmarkParams() - : num_runs(50), warmup_runs(1), run_delay(-1.0), num_threads(1) {} - int num_runs; - int warmup_runs; - float run_delay; - int num_threads; - std::string benchmark_name; - std::string output_prefix; -}; - class BenchmarkListener { public: virtual void OnBenchmarkStart(const BenchmarkParams& params) {} @@ -130,12 +120,22 @@ class BenchmarkLoggingListener : public BenchmarkListener { void OnBenchmarkEnd(const BenchmarkResults& results) override; }; +template +Flag CreateFlag(const char* name, BenchmarkParams* params, + const std::string& usage) { + return Flag(name, [params, name](const T& val) { params->Set(name, val); }, + params->Get(name), usage); +} + // Benchmarks a model. // // Subclasses need to implement initialization and running of the model. // The results can be collected by adding BenchmarkListener(s). class BenchmarkModel { public: + static BenchmarkParams DefaultParams(); + BenchmarkModel(); + BenchmarkModel(BenchmarkParams params) : params_(std::move(params)) {} virtual ~BenchmarkModel() {} bool ParseFlags(int argc, char** argv); virtual void Init() = 0; diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc new file mode 100644 index 0000000000000000000000000000000000000000..1dcf580a9d4995e6cb3706d3562bc8a2f4670082 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_params.h" + +#include +#include +#include + +#include "tensorflow/contrib/lite/tools/benchmark/logging.h" + +namespace tflite { +namespace benchmark { + +void BenchmarkParam::AssertHasSameType(BenchmarkParam::ParamType a, + BenchmarkParam::ParamType b) { + TFLITE_BENCHMARK_CHECK(a == b) << "Type mismatch while accessing parameter."; +} + +template <> +BenchmarkParam::ParamType BenchmarkParam::GetValueType() { + return BenchmarkParam::ParamType::TYPE_INT32; +} + +template <> +BenchmarkParam::ParamType BenchmarkParam::GetValueType() { + return BenchmarkParam::ParamType::TYPE_BOOL; +} + +template <> +BenchmarkParam::ParamType BenchmarkParam::GetValueType() { + return BenchmarkParam::ParamType::TYPE_FLOAT; +} + +template <> +BenchmarkParam::ParamType BenchmarkParam::GetValueType() { + return BenchmarkParam::ParamType::TYPE_STRING; +} + +void BenchmarkParams::AssertParamExists(const std::string& name) const { + TFLITE_BENCHMARK_CHECK(HasParam(name)) << name << " was not found."; +} + +} // namespace benchmark +} // namespace tflite diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h new file mode 100644 index 0000000000000000000000000000000000000000..33448dd1623577fdfda6316c588cc60ccbaa1994 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/tools/benchmark/logging.h" + +namespace tflite { +namespace benchmark { + +template +class TypedBenchmarkParam; + +class BenchmarkParam { + protected: + enum class ParamType { TYPE_INT32, TYPE_FLOAT, TYPE_BOOL, TYPE_STRING }; + + public: + template + static std::unique_ptr Create(const T& default_value) { + return std::unique_ptr( + new TypedBenchmarkParam(default_value)); + } + + template + TypedBenchmarkParam* AsTyped() { + AssertHasSameType(GetValueType(), type_); + return static_cast*>(this); + } + virtual ~BenchmarkParam() {} + BenchmarkParam(ParamType type) : type_(type) {} + + private: + static void AssertHasSameType(ParamType a, ParamType b); + template + static ParamType GetValueType(); + + const ParamType type_; +}; + +template +class TypedBenchmarkParam : public BenchmarkParam { + public: + TypedBenchmarkParam(const T& value) + : BenchmarkParam(GetValueType()), value_(value) {} + void Set(const T& value) { value_ = value; } + + T Get() { return value_; } + + private: + T value_; +}; + +class BenchmarkParams { + public: + void AddParam(const std::string& name, + std::unique_ptr value) { + params_[name] = std::move(value); + } + + bool HasParam(const std::string& name) const { + return params_.find(name) != params_.end(); + } + + template + void Set(const std::string& name, const T& value) { + AssertParamExists(name); + params_.at(name)->AsTyped()->Set(value); + } + + template + T Get(const std::string& name) const { + AssertParamExists(name); + return params_.at(name)->AsTyped()->Get(); + } + + private: + void AssertParamExists(const std::string& name) const; + std::unordered_map> params_; +}; + +} // namespace benchmark +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_ diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc index 5f803cec197858953180d379c763ed7ebd34ee1d..73affc26b034f415ae2a2101e0b558cdb94d8d5b 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc @@ -162,15 +162,37 @@ bool PopulateInputLayerInfo( return true; } +BenchmarkParams GetDefaultParams() { + BenchmarkParams default_params = BenchmarkModel::DefaultParams(); + default_params.AddParam("graph", BenchmarkParam::Create("")); + default_params.AddParam("input_layer", + BenchmarkParam::Create("")); + default_params.AddParam("input_layer_shape", + BenchmarkParam::Create("")); + default_params.AddParam("use_nnapi", BenchmarkParam::Create(false)); + return default_params; +} + } // namespace +BenchmarkTfLiteModel::BenchmarkTfLiteModel() + : BenchmarkModel(GetDefaultParams()) { + AddListener(&profiling_listener_); +} + +BenchmarkTfLiteModel::BenchmarkTfLiteModel(BenchmarkParams params) + : BenchmarkModel(std::move(params)) { + AddListener(&profiling_listener_); +} + std::vector BenchmarkTfLiteModel::GetFlags() { std::vector flags = BenchmarkTfLiteModel::BenchmarkModel::GetFlags(); std::vector specific_flags = { - Flag("graph", &graph, "graph file name"), - Flag("input_layer", &input_layer_string, "input layer names"), - Flag("input_layer_shape", &input_layer_shape_string, "input layer shape"), - Flag("use_nnapi", &use_nnapi, "use nnapi api")}; + CreateFlag("graph", ¶ms_, "graph file name"), + CreateFlag("input_layer", ¶ms_, "input layer names"), + CreateFlag("input_layer_shape", ¶ms_, + "input layer shape"), + CreateFlag("use_nnapi", ¶ms_, "use nnapi api")}; flags.insert(flags.end(), specific_flags.begin(), specific_flags.end()); return flags; @@ -178,19 +200,22 @@ std::vector BenchmarkTfLiteModel::GetFlags() { void BenchmarkTfLiteModel::LogFlags() { BenchmarkModel::LogFlags(); - TFLITE_LOG(INFO) << "Graph: [" << graph << "]"; - TFLITE_LOG(INFO) << "Input layers: [" << input_layer_string << "]"; - TFLITE_LOG(INFO) << "Input shapes: [" << input_layer_shape_string << "]"; - TFLITE_LOG(INFO) << "Use nnapi : [" << use_nnapi << "]"; + TFLITE_LOG(INFO) << "Graph: [" << params_.Get("graph") << "]"; + TFLITE_LOG(INFO) << "Input layers: [" + << params_.Get("input_layer") << "]"; + TFLITE_LOG(INFO) << "Input shapes: [" + << params_.Get("input_layer_shape") << "]"; + TFLITE_LOG(INFO) << "Use nnapi : [" << params_.Get("use_nnapi") << "]"; } bool BenchmarkTfLiteModel::ValidateFlags() { - if (graph.empty()) { + if (params_.Get("graph").empty()) { TFLITE_LOG(ERROR) << "Please specify the name of your TF Lite input file with --graph"; return false; } - return PopulateInputLayerInfo(input_layer_string, input_layer_shape_string, + return PopulateInputLayerInfo(params_.Get("input_layer"), + params_.Get("input_layer_shape"), &inputs); } @@ -205,6 +230,7 @@ uint64_t BenchmarkTfLiteModel::ComputeInputBytes() { } void BenchmarkTfLiteModel::Init() { + std::string graph = params_.Get("graph"); model = tflite::FlatBufferModel::BuildFromFile(graph.c_str()); if (!model) { TFLITE_LOG(FATAL) << "Failed to mmap model " << graph; @@ -226,10 +252,14 @@ void BenchmarkTfLiteModel::Init() { } profiling_listener_.SetInterpreter(interpreter.get()); - if (params_.num_threads != -1) { - interpreter->SetNumThreads(params_.num_threads); + const int32_t num_threads = params_.Get("num_threads"); + + if (num_threads != -1) { + interpreter->SetNumThreads(num_threads); } + bool use_nnapi = params_.Get("use_nnapi"); + interpreter->UseNNAPI(use_nnapi); auto interpreter_inputs = interpreter->inputs(); diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h index ffb93da964b2da0328616e749abd9c5a84189468..50cc3f24b3bd2f31555eac69ff208fa2480449b9 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h @@ -50,9 +50,8 @@ class ProfilingListener : public BenchmarkListener { // Benchmarks a TFLite model by running tflite interpreter. class BenchmarkTfLiteModel : public BenchmarkModel { public: - BenchmarkTfLiteModel() : use_nnapi(false) { - AddListener(&profiling_listener_); - } + BenchmarkTfLiteModel(); + BenchmarkTfLiteModel(BenchmarkParams params); std::vector GetFlags() override; void LogFlags() override; @@ -70,13 +69,7 @@ class BenchmarkTfLiteModel : public BenchmarkModel { private: std::unique_ptr model; std::unique_ptr interpreter; - std::string graph; - std::string input_layer_string; - std::string input_layer_type_string; - std::string input_layer_shape_string; - std::string input_layer_values_string; std::vector inputs; - bool use_nnapi; ProfilingListener profiling_listener_; }; diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc index 8195fc44beb288eec3c020791b47eefa01536fb7..ff818b9dcb5ee0b58b95c3dceae74083dbd4f0da 100644 --- a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include #include +#include #include namespace tflite { @@ -44,76 +45,79 @@ bool ParseFlag(const std::string& arg, const std::string& flag, } template -bool ParseFlag(const std::string& flag_value, T* value) { +bool ParseFlag(const std::string& flag_value, + const std::function& hook) { std::istringstream stream(flag_value); T read_value; stream >> read_value; if (!stream.eof() && !stream.good()) { return false; } - *value = read_value; + hook(read_value); return true; } -bool ParseBoolFlag(const std::string& flag_value, bool* value) { +bool ParseBoolFlag(const std::string& flag_value, + const std::function& hook) { if (flag_value != "true" && flag_value != "false") { return false; } - *value = (flag_value == "true"); + hook(flag_value == "true"); return true; } - -bool ParseStringFlag(const std::string& flag_value, std::string* value) { - *value = flag_value; - return true; -} - } // namespace -Flag::Flag(const char* name, int32_t* dst, const std::string& usage_text) +Flag::Flag(const char* name, const std::function& hook, + int32_t default_value, const std::string& usage_text) : name_(name), type_(TYPE_INT32), - value_hook_([dst](const std::string& flag_value) { - return ParseFlag(flag_value, dst); + value_hook_([hook](const std::string& flag_value) { + return ParseFlag(flag_value, hook); }), - default_for_display_(ToString(*dst)), + default_for_display_(ToString(default_value)), usage_text_(usage_text) {} -Flag::Flag(const char* name, int64_t* dst, const std::string& usage_text) +Flag::Flag(const char* name, const std::function& hook, + int64_t default_value, const std::string& usage_text) : name_(name), type_(TYPE_INT64), - value_hook_([dst](const std::string& flag_value) { - return ParseFlag(flag_value, dst); + value_hook_([hook](const std::string& flag_value) { + return ParseFlag(flag_value, hook); }), - default_for_display_(ToString(*dst)), + default_for_display_(ToString(default_value)), usage_text_(usage_text) {} -Flag::Flag(const char* name, float* dst, const std::string& usage_text) +Flag::Flag(const char* name, const std::function& hook, + float default_value, const std::string& usage_text) : name_(name), type_(TYPE_FLOAT), - value_hook_([dst](const std::string& flag_value) { - return ParseFlag(flag_value, dst); + value_hook_([hook](const std::string& flag_value) { + return ParseFlag(flag_value, hook); }), - default_for_display_(ToString(*dst)), + default_for_display_(ToString(default_value)), usage_text_(usage_text) {} -Flag::Flag(const char* name, bool* dst, const std::string& usage_text) +Flag::Flag(const char* name, const std::function& hook, + bool default_value, const std::string& usage_text) : name_(name), type_(TYPE_BOOL), - value_hook_([dst](const std::string& flag_value) { - return ParseBoolFlag(flag_value, dst); + value_hook_([hook](const std::string& flag_value) { + return ParseBoolFlag(flag_value, hook); }), - default_for_display_((*dst) ? "true" : "false"), + default_for_display_(default_value ? "true" : "false"), usage_text_(usage_text) {} -Flag::Flag(const char* name, std::string* dst, const std::string& usage_text) +Flag::Flag(const char* name, + const std::function& hook, + const std::string& default_value, const std::string& usage_text) : name_(name), type_(TYPE_STRING), - value_hook_([dst](const std::string& flag_value) { - return ParseStringFlag(flag_value, dst); + value_hook_([hook](const std::string& flag_value) { + hook(flag_value); + return true; }), - default_for_display_(*dst), + default_for_display_(default_value), usage_text_(usage_text) {} bool Flag::Parse(const std::string& arg, bool* value_parsing_ok) const { diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h index 36f9e64767315a317338bc4d2db2ec2d43bee875..2e514ae3ead3b602b8217998ec09177b1e6a2376 100644 --- a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h @@ -33,10 +33,11 @@ namespace tflite { // int some_int = 10; // bool some_switch = false; // std::string some_name = "something"; +// // std::vector flag_list = { -// Flag("some_int", &some_int, "an integer that affects X"), -// Flag("some_switch", &some_switch, "a bool that affects Y"), -// Flag("some_name", &some_name, "a std::string that affects Z") +// Flag::CreateFlag("some_int", &some_int, "an integer that affects X"), +// Flag::CreateFlag("some_switch", &some_switch, "a bool that affects Y"), +// Flag::CreateFlag("some_name", &some_name, "a string that affects Z") // }; // // Get usage message before ParseFlags() to capture default values. // std::string usage = Flag::Usage(argv[0], flag_list); @@ -63,11 +64,21 @@ namespace tflite { // text, and a pointer to the corresponding variable. class Flag { public: - Flag(const char* name, int32_t* dst, const std::string& usage_text); - Flag(const char* name, int64_t* dst, const std::string& usage_text); - Flag(const char* name, bool* dst, const std::string& usage_text); - Flag(const char* name, std::string* dst, const std::string& usage_text); - Flag(const char* name, float* dst, const std::string& usage_text); + template + static Flag CreateFlag(const char* name, T* val, const char* usage) { + return Flag(name, [val](const T& v) { *val = v; }, *val, usage); + } + + Flag(const char* name, const std::function& hook, + int32_t default_value, const std::string& usage_text); + Flag(const char* name, const std::function& hook, + int64_t default_value, const std::string& usage_text); + Flag(const char* name, const std::function& hook, + float default_value, const std::string& usage_text); + Flag(const char* name, const std::function& hook, + bool default_value, const std::string& usage_text); + Flag(const char* name, const std::function& hook, + const std::string& default_value, const std::string& usage_text); private: friend class Flags; diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc b/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc index 620d61b027d30044ba9d449a8e308375f72ad76f..03da8051099899241fa5241374d754adb1aa93c6 100644 --- a/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc @@ -34,15 +34,15 @@ TEST(CommandLineFlagsTest, BasicUsage) { "--some_name=somethingelse", "--some_float=42.0"}; int argc = 6; - bool parsed_ok = - Flags::Parse(&argc, reinterpret_cast(argv_strings), - { - Flag("some_int32", &some_int32, "some int32"), - Flag("some_int64", &some_int64, "some int64"), - Flag("some_switch", &some_switch, "some switch"), - Flag("some_name", &some_name, "some name"), - Flag("some_float", &some_float, "some float"), - }); + bool parsed_ok = Flags::Parse( + &argc, reinterpret_cast(argv_strings), + { + Flag::CreateFlag("some_int32", &some_int32, "some int32"), + Flag::CreateFlag("some_int64", &some_int64, "some int64"), + Flag::CreateFlag("some_switch", &some_switch, "some switch"), + Flag::CreateFlag("some_name", &some_name, "some name"), + Flag::CreateFlag("some_float", &some_float, "some float"), + }); EXPECT_EQ(true, parsed_ok); EXPECT_EQ(20, some_int32); @@ -57,9 +57,9 @@ TEST(CommandLineFlagsTest, EmptyStringFlag) { int argc = 2; std::string some_string = "invalid"; const char* argv_strings[] = {"program_name", "--some_string="}; - bool parsed_ok = - Flags::Parse(&argc, reinterpret_cast(argv_strings), - {Flag("some_string", &some_string, "some string")}); + bool parsed_ok = Flags::Parse( + &argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_string", &some_string, "some string")}); EXPECT_EQ(true, parsed_ok); EXPECT_EQ(some_string, ""); @@ -72,7 +72,7 @@ TEST(CommandLineFlagsTest, BadIntValue) { const char* argv_strings[] = {"program_name", "--some_int=notanumber"}; bool parsed_ok = Flags::Parse(&argc, reinterpret_cast(argv_strings), - {Flag("some_int", &some_int, "some int")}); + {Flag::CreateFlag("some_int", &some_int, "some int")}); EXPECT_EQ(false, parsed_ok); EXPECT_EQ(10, some_int); @@ -83,9 +83,9 @@ TEST(CommandLineFlagsTest, BadBoolValue) { bool some_switch = false; int argc = 2; const char* argv_strings[] = {"program_name", "--some_switch=notabool"}; - bool parsed_ok = - Flags::Parse(&argc, reinterpret_cast(argv_strings), - {Flag("some_switch", &some_switch, "some switch")}); + bool parsed_ok = Flags::Parse( + &argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_switch", &some_switch, "some switch")}); EXPECT_EQ(false, parsed_ok); EXPECT_EQ(false, some_switch); @@ -98,7 +98,7 @@ TEST(CommandLineFlagsTest, BadFloatValue) { const char* argv_strings[] = {"program_name", "--some_float=notanumber"}; bool parsed_ok = Flags::Parse(&argc, reinterpret_cast(argv_strings), - {Flag("some_float", &some_float, "some float")}); + {Flag::CreateFlag("some_float", &some_float, "some float")}); EXPECT_EQ(false, parsed_ok); EXPECT_NEAR(-23.23f, some_float, 1e-5f); @@ -136,10 +136,11 @@ TEST(CommandLineFlagsTest, UsageString) { // match against, and we don't want a flakey test. const std::string tool_name = "some_tool_name"; std::string usage = Flags::Usage( - tool_name + " ", {Flag("some_int", &some_int, "some int"), - Flag("some_int64", &some_int64, "some int64"), - Flag("some_switch", &some_switch, "some switch"), - Flag("some_name", &some_name, "some name")}); + tool_name + " ", + {Flag::CreateFlag("some_int", &some_int, "some int"), + Flag::CreateFlag("some_int64", &some_int64, "some int64"), + Flag::CreateFlag("some_switch", &some_switch, "some switch"), + Flag::CreateFlag("some_name", &some_name, "some name")}); // Match the usage message, being sloppy about whitespace. const char* expected_usage = " usage: some_tool_name \n" diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/README.md b/tensorflow/contrib/lite/tools/benchmark/ios/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c8d3307e29efaebdc5c309dc7e4262b54d64943f --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/README.md @@ -0,0 +1,43 @@ +# TFLite iOS benchmark app. + +## Description + +An iOS app to benchmark TFLite models. + +The app reads benchmark parameters from a JSON file named `benchmark_params.json` +in its `benchmark_data` directory. Any downloaded models for benchmarking should +also be placed in `benchmark_data` directory. + +The JSON file specifies the name of the model file and other benchmarking +parameters like inputs to the model, type of inputs, number of iterations, +number of threads. The default values in the JSON file are for the +Mobilenet_1.0_224 model +([paper](https://arxiv.org/pdf/1704.04861.pdf), +[tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz)) + +## To build/install/run + +- Follow instructions at [iOS build for TFLite] +(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/ios.md) +to build TFLite. + +Running + +```bash +tensorflow/contrib/lite/build_ios_universal_lib.sh +``` +will also build `tensorflow/contrib/lite/gen/lib/benchmark-lib.a` . + +- Now copy the downloaded model file to `benchmark_data` directory. + +- Modify `benchmark_params.json` change the `input_layer`, `input_layer_shape` +and other benchmark parameters. + +- Change `Build Phases -> Copy Bundle Resources` and add the model file to the +resources that need to be copied. + +- Ensure that `Build Phases -> Link Binary With Library` contains the +`Accelerate framework` and `tensorflow/contrib/lite/gen/lib/benchmark-lib.a`. + +- Now try running the app. The app has a single button that runs the benchmark + on the model and displays results in a text view below. diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark.xcodeproj/project.pbxproj b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark.xcodeproj/project.pbxproj new file mode 100644 index 0000000000000000000000000000000000000000..b908f733d49b56a6b41ebea4185f1fe8c11edc60 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark.xcodeproj/project.pbxproj @@ -0,0 +1,381 @@ +// !$*UTF8*$! +{ + archiveVersion = 1; + classes = { + }; + objectVersion = 50; + objects = { + +/* Begin PBXBuildFile section */ + 6FE7579A20D59CE500F01636 /* benchmark_params.json in Resources */ = {isa = PBXBuildFile; fileRef = 6FE7579920D59CE500F01636 /* benchmark_params.json */; }; + 6FE7579D20D5A5E000F01636 /* benchmark-lib.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 6FE7579C20D5A5E000F01636 /* benchmark-lib.a */; }; + 6FE7579F20D5A6A700F01636 /* Accelerate.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 6FE7579E20D5A6A700F01636 /* Accelerate.framework */; }; + 6FE757A120D5AB8100F01636 /* mobilenet_v1_1.0_224.tflite in Resources */ = {isa = PBXBuildFile; fileRef = 6FE757A020D5AB8000F01636 /* mobilenet_v1_1.0_224.tflite */; }; + 6FE93FFD20D592D8008C9FE4 /* AppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 6FE93FFC20D592D8008C9FE4 /* AppDelegate.m */; }; + 6FE9400020D592D8008C9FE4 /* BenchmarkViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 6FE93FFF20D592D8008C9FE4 /* BenchmarkViewController.mm */; }; + 6FE9400320D592D8008C9FE4 /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 6FE9400120D592D8008C9FE4 /* Main.storyboard */; }; + 6FE9400520D592DA008C9FE4 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 6FE9400420D592DA008C9FE4 /* Assets.xcassets */; }; + 6FE9400B20D592DA008C9FE4 /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 6FE9400A20D592DA008C9FE4 /* main.m */; }; +/* End PBXBuildFile section */ + +/* Begin PBXFileReference section */ + 6FE7579920D59CE500F01636 /* benchmark_params.json */ = {isa = PBXFileReference; lastKnownFileType = text.json; path = benchmark_params.json; sourceTree = ""; }; + 6FE7579C20D5A5E000F01636 /* benchmark-lib.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "benchmark-lib.a"; path = "$SRCROOT/../../../../../../../tensorflow/contrib/lite/gen/lib/benchmark-lib.a"; sourceTree = ""; }; + 6FE7579E20D5A6A700F01636 /* Accelerate.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Accelerate.framework; path = System/Library/Frameworks/Accelerate.framework; sourceTree = SDKROOT; }; + 6FE757A020D5AB8000F01636 /* mobilenet_v1_1.0_224.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = mobilenet_v1_1.0_224.tflite; sourceTree = ""; }; + 6FE93FF820D592D8008C9FE4 /* TFLiteBenchmark.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = TFLiteBenchmark.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 6FE93FFB20D592D8008C9FE4 /* AppDelegate.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; }; + 6FE93FFC20D592D8008C9FE4 /* AppDelegate.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = AppDelegate.m; sourceTree = ""; }; + 6FE93FFE20D592D8008C9FE4 /* BenchmarkViewController.h */ = {isa = PBXFileReference; explicitFileType = sourcecode.cpp.h; path = BenchmarkViewController.h; sourceTree = ""; }; + 6FE93FFF20D592D8008C9FE4 /* BenchmarkViewController.mm */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.objcpp; path = BenchmarkViewController.mm; sourceTree = ""; }; + 6FE9400220D592D8008C9FE4 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/Main.storyboard; sourceTree = ""; }; + 6FE9400420D592DA008C9FE4 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; + 6FE9400920D592DA008C9FE4 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = ""; }; + 6FE9400A20D592DA008C9FE4 /* main.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = main.m; sourceTree = ""; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + 6FE93FF520D592D8008C9FE4 /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + 6FE7579F20D5A6A700F01636 /* Accelerate.framework in Frameworks */, + 6FE7579D20D5A5E000F01636 /* benchmark-lib.a in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + 6FE7579820D59C8B00F01636 /* benchmark_data */ = { + isa = PBXGroup; + children = ( + 6FE757A020D5AB8000F01636 /* mobilenet_v1_1.0_224.tflite */, + 6FE7579920D59CE500F01636 /* benchmark_params.json */, + ); + path = benchmark_data; + sourceTree = ""; + }; + 6FE7579B20D5A5E000F01636 /* Frameworks */ = { + isa = PBXGroup; + children = ( + 6FE7579E20D5A6A700F01636 /* Accelerate.framework */, + 6FE7579C20D5A5E000F01636 /* benchmark-lib.a */, + ); + name = Frameworks; + sourceTree = ""; + }; + 6FE93FEF20D592D8008C9FE4 = { + isa = PBXGroup; + children = ( + 6FE93FFA20D592D8008C9FE4 /* TFLiteBenchmark */, + 6FE93FF920D592D8008C9FE4 /* Products */, + 6FE7579B20D5A5E000F01636 /* Frameworks */, + ); + sourceTree = ""; + }; + 6FE93FF920D592D8008C9FE4 /* Products */ = { + isa = PBXGroup; + children = ( + 6FE93FF820D592D8008C9FE4 /* TFLiteBenchmark.app */, + ); + name = Products; + sourceTree = ""; + }; + 6FE93FFA20D592D8008C9FE4 /* TFLiteBenchmark */ = { + isa = PBXGroup; + children = ( + 6FE7579820D59C8B00F01636 /* benchmark_data */, + 6FE93FFB20D592D8008C9FE4 /* AppDelegate.h */, + 6FE93FFC20D592D8008C9FE4 /* AppDelegate.m */, + 6FE93FFE20D592D8008C9FE4 /* BenchmarkViewController.h */, + 6FE93FFF20D592D8008C9FE4 /* BenchmarkViewController.mm */, + 6FE9400120D592D8008C9FE4 /* Main.storyboard */, + 6FE9400420D592DA008C9FE4 /* Assets.xcassets */, + 6FE9400920D592DA008C9FE4 /* Info.plist */, + 6FE9400A20D592DA008C9FE4 /* main.m */, + ); + path = TFLiteBenchmark; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + 6FE93FF720D592D8008C9FE4 /* TFLiteBenchmark */ = { + isa = PBXNativeTarget; + buildConfigurationList = 6FE9400E20D592DA008C9FE4 /* Build configuration list for PBXNativeTarget "TFLiteBenchmark" */; + buildPhases = ( + 6FE93FF420D592D8008C9FE4 /* Sources */, + 6FE93FF520D592D8008C9FE4 /* Frameworks */, + 6FE93FF620D592D8008C9FE4 /* Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = TFLiteBenchmark; + productName = TFLiteBenchmark; + productReference = 6FE93FF820D592D8008C9FE4 /* TFLiteBenchmark.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + 6FE93FF020D592D8008C9FE4 /* Project object */ = { + isa = PBXProject; + attributes = { + LastUpgradeCheck = 1000; + ORGANIZATIONNAME = Example; + TargetAttributes = { + 6FE93FF720D592D8008C9FE4 = { + CreatedOnToolsVersion = 10.0; + }; + }; + }; + buildConfigurationList = 6FE93FF320D592D8008C9FE4 /* Build configuration list for PBXProject "TFLiteBenchmark" */; + compatibilityVersion = "Xcode 9.3"; + developmentRegion = en; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = 6FE93FEF20D592D8008C9FE4; + productRefGroup = 6FE93FF920D592D8008C9FE4 /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + 6FE93FF720D592D8008C9FE4 /* TFLiteBenchmark */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + 6FE93FF620D592D8008C9FE4 /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 6FE757A120D5AB8100F01636 /* mobilenet_v1_1.0_224.tflite in Resources */, + 6FE9400520D592DA008C9FE4 /* Assets.xcassets in Resources */, + 6FE9400320D592D8008C9FE4 /* Main.storyboard in Resources */, + 6FE7579A20D59CE500F01636 /* benchmark_params.json in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + 6FE93FF420D592D8008C9FE4 /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 6FE9400020D592D8008C9FE4 /* BenchmarkViewController.mm in Sources */, + 6FE9400B20D592DA008C9FE4 /* main.m in Sources */, + 6FE93FFD20D592D8008C9FE4 /* AppDelegate.m in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin PBXVariantGroup section */ + 6FE9400120D592D8008C9FE4 /* Main.storyboard */ = { + isa = PBXVariantGroup; + children = ( + 6FE9400220D592D8008C9FE4 /* Base */, + ); + name = Main.storyboard; + sourceTree = ""; + }; +/* End PBXVariantGroup section */ + +/* Begin XCBuildConfiguration section */ + 6FE9400C20D592DA008C9FE4 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_IDENTITY = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 11.0; + MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; + ONLY_ACTIVE_ARCH = YES; + OTHER_CFLAGS = ""; + OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; + SDKROOT = iphoneos; + }; + name = Debug; + }; + 6FE9400D20D592DA008C9FE4 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_IDENTITY = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 11.0; + MTL_ENABLE_DEBUG_INFO = NO; + OTHER_CFLAGS = ""; + OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; + SDKROOT = iphoneos; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; + 6FE9400F20D592DA008C9FE4 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CODE_SIGN_STYLE = Automatic; + "HEADER_SEARCH_PATHS[arch=*]" = ( + $SRCROOT/../../../../../../../, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/eigen, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/gemmlowp, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/neon_2_sse, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/farmhash/src, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/flatbuffers/include, + ); + INFOPLIST_FILE = TFLiteBenchmark/Info.plist; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + "LIBRARY_SEARCH_PATHS[arch=*]" = $SRCROOT/../../../../../../../tensorflow/contrib/lite/gen/lib; + PRODUCT_BUNDLE_IDENTIFIER = example.TFLiteBenchmark; + PRODUCT_NAME = "$(TARGET_NAME)"; + TARGETED_DEVICE_FAMILY = "1,2"; + "USER_HEADER_SEARCH_PATHS[arch=*]" = ""; + }; + name = Debug; + }; + 6FE9401020D592DA008C9FE4 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CODE_SIGN_STYLE = Automatic; + "HEADER_SEARCH_PATHS[arch=*]" = ( + $SRCROOT/../../../../../../../, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/eigen, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/gemmlowp, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/neon_2_sse, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/farmhash/src, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/flatbuffers/include, + ); + INFOPLIST_FILE = TFLiteBenchmark/Info.plist; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + "LIBRARY_SEARCH_PATHS[arch=*]" = $SRCROOT/../../../../../../../tensorflow/contrib/lite/gen/lib; + PRODUCT_BUNDLE_IDENTIFIER = example.TFLiteBenchmark; + PRODUCT_NAME = "$(TARGET_NAME)"; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + 6FE93FF320D592D8008C9FE4 /* Build configuration list for PBXProject "TFLiteBenchmark" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 6FE9400C20D592DA008C9FE4 /* Debug */, + 6FE9400D20D592DA008C9FE4 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + 6FE9400E20D592DA008C9FE4 /* Build configuration list for PBXNativeTarget "TFLiteBenchmark" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 6FE9400F20D592DA008C9FE4 /* Debug */, + 6FE9401020D592DA008C9FE4 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + }; + rootObject = 6FE93FF020D592D8008C9FE4 /* Project object */; +} diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.h b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.h new file mode 100644 index 0000000000000000000000000000000000000000..a55c03e00b5065e3b149c65f820f11d13c064d87 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.h @@ -0,0 +1,22 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface AppDelegate : UIResponder + +@property(strong, nonatomic) UIWindow *window; + +@end diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.m b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.m new file mode 100644 index 0000000000000000000000000000000000000000..b1165940e9a29ac693d473a1c852b7b0681392fc --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.m @@ -0,0 +1,27 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "AppDelegate.h" + +@interface AppDelegate () + +@end + +@implementation AppDelegate +- (BOOL)application:(UIApplication *)application + didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { + return YES; +} +@end diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/AppIcon.appiconset/Contents.json b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 0000000000000000000000000000000000000000..d8db8d65fd79fd541b2b7eba75c7378af3448f9c --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1,98 @@ +{ + "images" : [ + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "3x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "83.5x83.5", + "scale" : "2x" + }, + { + "idiom" : "ios-marketing", + "size" : "1024x1024", + "scale" : "1x" + } + ], + "info" : { + "version" : 1, + "author" : "xcode" + } +} \ No newline at end of file diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/Contents.json b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/Contents.json new file mode 100644 index 0000000000000000000000000000000000000000..da4a164c918651cdd1e11dca5cc62c333f097601 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/Contents.json @@ -0,0 +1,6 @@ +{ + "info" : { + "version" : 1, + "author" : "xcode" + } +} \ No newline at end of file diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/LaunchScreen.storyboard b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/LaunchScreen.storyboard new file mode 100644 index 0000000000000000000000000000000000000000..bfa36129419f8bd7ad73581cb9f07b8c6eec3fcf --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/LaunchScreen.storyboard @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/Main.storyboard b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/Main.storyboard new file mode 100644 index 0000000000000000000000000000000000000000..adcfe1ef4e708ea6f87c77f4a740b58e5027d3e5 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/Main.storyboard @@ -0,0 +1,60 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.h b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.h new file mode 100644 index 0000000000000000000000000000000000000000..ec6dea0546060881682c44ad451f4812a2f3d7ea --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.h @@ -0,0 +1,21 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface BenchmarkViewController : UIViewController +@property(weak, nonatomic) IBOutlet UITextView *resultsView; + +@end diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.mm b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.mm new file mode 100644 index 0000000000000000000000000000000000000000..356d5b0e17abc715de9b8f7a20ec7459f3468da1 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.mm @@ -0,0 +1,125 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "BenchmarkViewController.h" +#import +#import +#import +#import +#import "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h" +#import "tensorflow/contrib/lite/tools/benchmark/logging.h" + +namespace { +NSString* FilePathForResourceName(NSString* filename) { + NSString* name = [filename stringByDeletingPathExtension]; + NSString* extension = [filename pathExtension]; + NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension]; + if (file_path == NULL) { + TFLITE_LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." << [extension UTF8String] + << "' in bundle."; + } + return file_path; +} + +NSDictionary* ParseJson() { + NSString* params_json_path = FilePathForResourceName(@"benchmark_params.json"); + NSData* data = [NSData dataWithContentsOfFile:params_json_path]; + return [NSJSONSerialization JSONObjectWithData:data options:kNilOptions error:nil]; +} + +std::string FormatCommandLineParam(NSString* key, NSString* value) { + std::ostringstream stream; + stream << "--" << [key UTF8String] << "=" << [value UTF8String]; + return stream.str(); +} + +// Reads the |benchmark_params.json| to read command line parameters and returns them as a vector of +// strings. +void ReadCommandLineParameters(std::vector* params) { + NSDictionary* param_dict = ParseJson(); + for (NSString* key in param_dict) { + NSString* value = param_dict[key]; + if ([key isEqualToString:@"graph"]) { + value = FilePathForResourceName(value); + } + params->push_back(FormatCommandLineParam(key, value)); + } +} +std::vector StringVecToCharPtrVec(const std::vector& str_vec) { + std::vector charptr_vec; + std::transform(str_vec.begin(), str_vec.end(), std::back_inserter(charptr_vec), + [](const std::string& s) -> char* { return const_cast(s.c_str()); }); + return charptr_vec; +} + +class ResultsListener : public tflite::benchmark::BenchmarkListener { + public: + void OnBenchmarkEnd(const tflite::benchmark::BenchmarkResults& results) override; + std::string Results() { return results_; } + + private: + std::string results_; +}; + +void OutputMicrosecondsStatToStream(const tensorflow::Stat& time_us, + const std::string& prefix, std::ostringstream* stream) { + *stream << prefix << "Num runs: " << time_us.count() << "\n"; + + *stream << prefix << "Average: " << time_us.avg() / 1e3 << " ms\n"; + *stream << prefix << "Min: " << time_us.min() / 1e3 << " ms \n"; + *stream << prefix << "Max: " << time_us.max() / 1e3 << " ms \n"; + *stream << prefix << "Std deviation: " << time_us.std_deviation() / 1e3 << " ms\n"; +} + +void ResultsListener::OnBenchmarkEnd(const tflite::benchmark::BenchmarkResults& results) { + std::ostringstream stream; + const std::string prefix = " - "; + stream << "Startup latency: "; + stream << results.startup_latency_us() / 1e3 << " ms\n"; + stream << "\nInference:\n"; + OutputMicrosecondsStatToStream(results.inference_time_us(), prefix, &stream); + stream << "\nWarmup:\n"; + OutputMicrosecondsStatToStream(results.warmup_time_us(), prefix, &stream); + + results_ = stream.str(); +} + +std::string RunBenchmark() { + ResultsListener listener; + tflite::benchmark::BenchmarkTfLiteModel benchmark; + benchmark.AddListener(&listener); + // TODO(shashishekhar): Passing arguments like this is brittle, refactor the BenchmarkParams + // so that it contains arguments for BenchmarkTfLiteModel and set parameters using BenchmarkParams + std::vector command_line_params; + // Benchmark model expects first arg to be program name. + // push a string for name of program. + command_line_params.push_back("benchmark_tflite_model"); + ReadCommandLineParameters(&command_line_params); + std::vector argv = StringVecToCharPtrVec(command_line_params); + int argc = static_cast(argv.size()); + benchmark.Run(argc, argv.data()); + return listener.Results(); +} +} // namespace + +@interface BenchmarkViewController () +@end + +@implementation BenchmarkViewController +- (IBAction)onBenchmarkModel:(UIButton*)sender { + std::string results = RunBenchmark(); + [_resultsView setText:[NSString stringWithUTF8String:results.c_str()]]; +} +@end diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Info.plist b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Info.plist new file mode 100644 index 0000000000000000000000000000000000000000..96051cf08ff54b51f458eca6f0126dd99dfc51dc --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Info.plist @@ -0,0 +1,43 @@ + + + + + UILaunchStoryboardName + Main + CFBundleDevelopmentRegion + $(DEVELOPMENT_LANGUAGE) + CFBundleExecutable + $(EXECUTABLE_NAME) + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + $(PRODUCT_NAME) + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleVersion + 1 + LSRequiresIPhoneOS + + UIMainStoryboardFile + Main + UIRequiredDeviceCapabilities + + armv7 + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + UIInterfaceOrientationPortraitUpsideDown + UIInterfaceOrientationLandscapeLeft + UIInterfaceOrientationLandscapeRight + + + diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/benchmark_data/benchmark_params.json b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/benchmark_data/benchmark_params.json new file mode 100644 index 0000000000000000000000000000000000000000..d344a7a5efaef53500bc0f88d29ca7aecf59290a --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/benchmark_data/benchmark_params.json @@ -0,0 +1,10 @@ +{ + "benchmark_name" : "mobile_net_benchmark", + "num_threads" : "4", + "num_runs" : "20", + "warmup_runs" : "1", + "graph" : "mobilenet_v1_1.0_224.tflite", + "input_layer" : "input", + "input_layer_shape" : "1,224,224,3", + "run_delay" : "-1" +} diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/main.m b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/main.m new file mode 100644 index 0000000000000000000000000000000000000000..1e70b9cd1d82f320ec048642520dbc54dc0f7934 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/main.m @@ -0,0 +1,23 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import "AppDelegate.h" + +int main(int argc, char* argv[]) { + @autoreleasepool { + return UIApplicationMain(argc, argv, nil, NSStringFromClass([AppDelegate class])); + } +} diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 5a080cceabb55c307dcd1a457a9e30d24e0bd172..889accdd5aafae2931048ffdd26408cccb3c874e 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -1397,7 +1397,7 @@ class KeyValueTensorInitializerTest(test.TestCase): class IndexTableFromTensor(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_index_table_from_tensor_with_tensor_init(self): table = lookup.index_table_from_tensor( mapping=("brain", "salad", "surgery"), num_oov_buckets=1) @@ -1670,7 +1670,7 @@ class InitializeTableFromFileOpTest(test.TestCase): f.write("\n".join(values) + "\n") return vocabulary_file - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInitializeStringTable(self): vocabulary_file = self._createVocabFile("one_column_1.txt") default_value = -1 diff --git a/tensorflow/contrib/makefile/build_all_android.sh b/tensorflow/contrib/makefile/build_all_android.sh index fc88f59e0948e1d3ed7cce9b809bf30ba280af12..fb9e77ae1bcfc3404f1fdf90ab2697a4e79a9836 100755 --- a/tensorflow/contrib/makefile/build_all_android.sh +++ b/tensorflow/contrib/makefile/build_all_android.sh @@ -30,6 +30,14 @@ arm64-v8a armeabi armeabi-v7a mips mips64 x86 x86_64 tegra)" exit 1 } +echo "********************************************************************" +echo "TensorFlow Lite is the recommended library for mobile and embedded machine learning inference." +echo "You are currently using an older version. Please switch over to TensorFlow Lite." +echo "" +echo "Link to the code: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite" +echo "********************************************************************" +echo "" + if [[ -z "${NDK_ROOT}" ]]; then echo "NDK_ROOT should be set as an environment variable" 1>&2 exit 1 diff --git a/tensorflow/contrib/makefile/build_all_ios.sh b/tensorflow/contrib/makefile/build_all_ios.sh index 0a458a27b3ac9b1a24b0f42de2f0166d515e8cd9..1d4677ef4bd1e8811998d1464e63902544153a49 100755 --- a/tensorflow/contrib/makefile/build_all_ios.sh +++ b/tensorflow/contrib/makefile/build_all_ios.sh @@ -31,6 +31,14 @@ usage() { exit 1 } +echo "********************************************************************" +echo "TensorFlow Lite is the recommended library for mobile and embedded machine learning inference." +echo "You are currently using an older version. Please switch over to TensorFlow Lite." +echo "" +echo "Link to the code: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite" +echo "********************************************************************" +echo "" + DEFAULT_ARCH="i386 x86_64 armv7 armv7s arm64" while getopts "a:g:T" opt_name; do case "$opt_name" in diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 89db9ee2794ddf0a99951dca327e74c5d9694d23..6e7423f85e3b66e2f40b25c0b83d0fcaa54817a9 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -92,6 +92,7 @@ tensorflow/core/kernels/reduction_ops_common.cc tensorflow/core/kernels/reduction_ops_any.cc tensorflow/core/kernels/reduction_ops_all.cc tensorflow/core/kernels/roll_op.cc +tensorflow/core/kernels/queue_op.cc tensorflow/core/kernels/queue_ops.cc tensorflow/core/kernels/queue_base.cc tensorflow/core/kernels/pooling_ops_common.cc diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index a6be2084aae6bb05f958929b45977ed21b570603..b14202ff9ec38016f926ee37c8acbd2bbb4c6ef5 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -1064,7 +1064,7 @@ def streaming_auc(predictions, name=name) -def _compute_dynamic_auc(labels, predictions, curve='ROC'): +def _compute_dynamic_auc(labels, predictions, curve='ROC', weights=None): """Computes the apporixmate AUC by a Riemann sum with data-derived thresholds. Computes the area under the ROC or PR curve using each prediction as a @@ -1077,13 +1077,22 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'): predictions: A 1-D `Tensor` of predictions whose values are `float64`. curve: The name of the curve to be computed, 'ROC' for the Receiving Operating Characteristic or 'PR' for the Precision-Recall curve. + weights: A 1-D `Tensor` of weights whose values are `float64`. Returns: A scalar `Tensor` containing the area-under-curve value for the input. """ - # Count the total number of positive and negative labels in the input. + # Compute the total weight and the total positive weight. size = array_ops.size(predictions) - total_positive = math_ops.cast(math_ops.reduce_sum(labels), dtypes.int32) + if weights is None: + weights = array_ops.ones_like(labels, dtype=dtypes.float64) + labels, predictions, weights = metrics_impl._remove_squeezable_dimensions( + labels, predictions, weights) + total_weight = math_ops.reduce_sum(weights) + total_positive = math_ops.reduce_sum( + array_ops.where( + math_ops.greater(labels, 0), weights, + array_ops.zeros_like(labels, dtype=dtypes.float64))) def continue_computing_dynamic_auc(): """Continues dynamic auc computation, entered if labels are not all equal. @@ -1091,9 +1100,11 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'): Returns: A scalar `Tensor` containing the area-under-curve value. """ - # Sort the predictions descending, and the corresponding labels as well. + # Sort the predictions descending, keeping the same order for the + # corresponding labels and weights. ordered_predictions, indices = nn.top_k(predictions, k=size) ordered_labels = array_ops.gather(labels, indices) + ordered_weights = array_ops.gather(weights, indices) # Get the counts of the unique ordered predictions. _, _, counts = array_ops.unique_with_counts(ordered_predictions) @@ -1103,23 +1114,39 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'): array_ops.pad(math_ops.cumsum(counts), paddings=[[1, 0]]), dtypes.int32) # Count the positives to the left of the split indices. - positives = math_ops.cast( - array_ops.pad(math_ops.cumsum(ordered_labels), paddings=[[1, 0]]), - dtypes.int32) - true_positives = array_ops.gather(positives, splits) + true_positives = array_ops.gather( + array_ops.pad( + math_ops.cumsum( + array_ops.where( + math_ops.greater(ordered_labels, 0), ordered_weights, + array_ops.zeros_like(ordered_labels, + dtype=dtypes.float64))), + paddings=[[1, 0]]), splits) if curve == 'ROC': - # Count the negatives to the left of every split point and the total - # number of negatives for computing the FPR. - false_positives = math_ops.subtract(splits, true_positives) - total_negative = size - total_positive + # Compute the weight of the negatives to the left of every split point and + # the total weight of the negatives number of negatives for computing the + # FPR. + false_positives = array_ops.gather( + array_ops.pad( + math_ops.cumsum( + array_ops.where( + math_ops.less(ordered_labels, 1), ordered_weights, + array_ops.zeros_like( + ordered_labels, dtype=dtypes.float64))), + paddings=[[1, 0]]), splits) + total_negative = total_weight - total_positive x_axis_values = math_ops.truediv(false_positives, total_negative) y_axis_values = math_ops.truediv(true_positives, total_positive) elif curve == 'PR': x_axis_values = math_ops.truediv(true_positives, total_positive) # For conformance, set precision to 1 when the number of positive # classifications is 0. + positives = array_ops.gather( + array_ops.pad(math_ops.cumsum(ordered_weights), paddings=[[1, 0]]), + splits) y_axis_values = array_ops.where( - math_ops.greater(splits, 0), math_ops.truediv(true_positives, splits), + math_ops.greater(splits, 0), + math_ops.truediv(true_positives, positives), array_ops.ones_like(true_positives, dtype=dtypes.float64)) # Calculate trapezoid areas. @@ -1133,7 +1160,7 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'): return control_flow_ops.cond( math_ops.logical_or( math_ops.equal(total_positive, 0), math_ops.equal( - total_positive, size)), + total_positive, total_weight)), true_fn=lambda: array_ops.constant(0, dtypes.float64), false_fn=continue_computing_dynamic_auc) @@ -1143,7 +1170,8 @@ def streaming_dynamic_auc(labels, curve='ROC', metrics_collections=(), updates_collections=(), - name=None): + name=None, + weights=None): """Computes the apporixmate AUC by a Riemann sum with data-derived thresholds. USAGE NOTE: this approach requires storing all of the predictions and labels @@ -1168,6 +1196,8 @@ def streaming_dynamic_auc(labels, should be added to. name: An optional name for the variable_scope that contains the metric variables. + weights: A 'Tensor' of non-negative weights whose values are castable to + `float64`. Will be flattened into a 1-D `Tensor`. Returns: auc: A scalar `Tensor` containing the current area-under-curve value. @@ -1195,14 +1225,24 @@ def streaming_dynamic_auc(labels, check_ops.assert_less_equal( labels, array_ops.ones_like(labels, dtypes.int64), - message='labels must be 0 or 1, at least one is >1') + message='labels must be 0 or 1, at least one is >1'), ]): preds_accum, update_preds = streaming_concat( predictions, name='concat_preds') labels_accum, update_labels = streaming_concat( labels, name='concat_labels') - update_op = control_flow_ops.group(update_labels, update_preds) - auc = _compute_dynamic_auc(labels_accum, preds_accum, curve=curve) + if weights is not None: + weights = array_ops.reshape( + math_ops.cast(weights, dtypes.float64), [-1]) + weights_accum, update_weights = streaming_concat( + weights, name='concat_weights') + update_op = control_flow_ops.group(update_labels, update_preds, + update_weights) + else: + weights_accum = None + update_op = control_flow_ops.group(update_labels, update_preds) + auc = _compute_dynamic_auc( + labels_accum, preds_accum, curve=curve, weights=weights_accum) if updates_collections: ops.add_to_collections(updates_collections, update_op) if metrics_collections: diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index e720097636fdbe767ca3180345ecd93504c89d55..a09fc4abd461323d67e914c70932688816fed764 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -2127,6 +2127,44 @@ class StreamingDynamicAUCTest(test.TestCase): sess.run(update_op) self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-5) + def testWithWeights(self): + batch_size = 10 + num_batches = 100 + labels = np.array([]) + predictions = np.array([]) + weights = np.array([]) + tf_labels = variables.Variable( + array_ops.ones(batch_size, dtypes_lib.int32), + collections=[ops.GraphKeys.LOCAL_VARIABLES], + dtype=dtypes_lib.int32) + tf_predictions = variables.Variable( + array_ops.ones(batch_size), + collections=[ops.GraphKeys.LOCAL_VARIABLES], + dtype=dtypes_lib.float32) + tf_weights = variables.Variable( + array_ops.ones(batch_size), + collections=[ops.GraphKeys.LOCAL_VARIABLES], + dtype=dtypes_lib.float32) + auc, update_op = metrics.streaming_dynamic_auc(tf_labels, + tf_predictions, + weights=tf_weights) + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + for _ in xrange(num_batches): + new_labels = np.random.randint(0, 2, size=batch_size) + noise = np.random.uniform(-0.2, 0.2, size=batch_size) + new_predictions = 0.4 + 0.2 * new_labels + noise + new_weights = np.random.uniform(0.0, 3.0, size=batch_size) + labels = np.concatenate([labels, new_labels]) + predictions = np.concatenate([predictions, new_predictions]) + weights = np.concatenate([weights, new_weights]) + sess.run([tf_labels.assign(new_labels), + tf_predictions.assign(new_predictions), + tf_weights.assign(new_weights)]) + sess.run(update_op) + expected_auc = _np_auc(predictions, labels, weights) + self.assertAlmostEqual(expected_auc, auc.eval()) + class AucWithConfidenceIntervalsTest(test.TestCase): diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py index 480f5f6eaf493c5c87c27cc9f8e510ea9c085a72..1b0383d24c0c472b4875d15c3650e37dfd2439e1 100644 --- a/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py @@ -34,7 +34,7 @@ def _GetExampleIter(inputs): class FixedLossScaleManagerTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_basic(self): itr = _GetExampleIter([True] * 10 + [False] * 10) @@ -84,13 +84,13 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase): actual_outputs.append(self.evaluate(lsm.get_loss_scale())) self.assertEqual(actual_outputs, expected_outputs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_increase_every_n_steps(self): inputs = [True] * 6 expected_outputs = [1, 2, 2, 4, 4, 8] self._test_helper(inputs, expected_outputs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_keep_increasing_until_capped(self): init_loss_scale = np.finfo(np.float32).max / 4 + 10 max_float = np.finfo(np.float32).max @@ -104,7 +104,7 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase): self._test_helper(inputs, expected_outputs, init_loss_scale) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_decrease_every_n_steps(self): inputs = [False] * 6 init_loss_scale = 1024 @@ -112,7 +112,7 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase): self._test_helper(inputs, expected_outputs, init_loss_scale) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_keep_decreasing_until_one(self): inputs = [False] * 10 init_loss_scale = 16 @@ -120,19 +120,19 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase): self._test_helper(inputs, expected_outputs, init_loss_scale) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_incr_bad_step_clear_good_step(self): inputs = [True, True, True, False, True] expected_outputs = [1, 2, 2, 2, 2] self._test_helper(inputs, expected_outputs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_incr_good_step_does_not_clear_bad_step(self): inputs = [True, True, True, False, True, False] expected_outputs = [1, 2, 2, 2, 2, 1] self._test_helper(inputs, expected_outputs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_trigger_loss_scale_update_each_step(self): """Test when incr_every_n_step and decr_every_n_nan_or_inf is 1.""" init_loss_scale = 1 @@ -145,7 +145,7 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase): self._test_helper(inputs, expected_outputs, init_loss_scale, incr_every_n_step, decr_every_n_nan_or_inf) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_alternating_good_and_bad_gradients_trigger_each_step(self): init_loss_scale = 1 incr_every_n_step = 1 @@ -156,7 +156,7 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase): self._test_helper(inputs, expected_outputs, init_loss_scale, incr_every_n_step, decr_every_n_nan_or_inf) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_alternating_good_and_bad_gradients_trigger_incr_every_2steps(self): init_loss_scale = 32 incr_every_n_step = 2 @@ -167,7 +167,7 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase): self._test_helper(inputs, expected_outputs, init_loss_scale, incr_every_n_step, decr_every_n_nan_or_inf) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_random_mix_good_and_bad_gradients(self): init_loss_scale = 4 inputs = [ diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py index dded61ccd58eb79b338d7264e8a057c9456c8695..9009df0eefec13146090ba5fc2096e71ba6eb89d 100644 --- a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py @@ -54,7 +54,7 @@ class LossScaleOptimizerTest(test.TestCase): opt = loss_scale_opt_fn(opt) return x, loss, opt - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_float16_underflow_without_loss_scale(self): lr = 1 init_val = 1. @@ -73,7 +73,7 @@ class LossScaleOptimizerTest(test.TestCase): rtol=0, atol=min(symbolic_update, 1e-6)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_float16_with_loss_scale(self): lr = 1. init_val = 1. @@ -95,7 +95,7 @@ class LossScaleOptimizerTest(test.TestCase): rtol=0, atol=min(expected_update, 1e-6)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_compute_gradients_with_loss_scale(self): lr = 1 init_val = 1. @@ -115,7 +115,7 @@ class LossScaleOptimizerTest(test.TestCase): # Gradients aren't applied. self.assertAllClose(init_val, self.evaluate(x), rtol=0, atol=1e-6) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_compute_gradients_without_loss_scale(self): lr = 1 init_val = 1. @@ -127,7 +127,7 @@ class LossScaleOptimizerTest(test.TestCase): g_v = self.evaluate(grads_and_vars[0][0]) self.assertAllClose(g_v, 0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_apply_gradients(self): x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32) @@ -155,7 +155,7 @@ class LossScaleOptimizerTest(test.TestCase): actual_output.append(self.evaluate(x)) self.assertAllClose(expected_output, actual_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_apply_gradients_loss_scale_is_updated(self): class SimpleLossScaleManager(lsm_lib.LossScaleManager): diff --git a/tensorflow/contrib/mpi_collectives/BUILD b/tensorflow/contrib/mpi_collectives/BUILD index a7be92a35e0d62a61f7923ac61bb2c1267d039c6..ecac06354d2ce796f2a6021cdf2370d7c30ccab7 100644 --- a/tensorflow/contrib/mpi_collectives/BUILD +++ b/tensorflow/contrib/mpi_collectives/BUILD @@ -52,6 +52,7 @@ tf_custom_op_library( deps = [ ":mpi_defines", ":mpi_message_proto_cc", + "//tensorflow/stream_executor:stream_executor_headers_lib", "//third_party/mpi", ], ) diff --git a/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc index ed22ee667f1d73b3f86f77e09bad9bfec7e46391..e4b0c2c6541836243347d2950686c60ef06d2bfc 100644 --- a/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc +++ b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc @@ -73,7 +73,7 @@ limitations under the License. */ template -using StatusOr = se::port::StatusOr; +using StatusOr = stream_executor::port::StatusOr; using CPUDevice = Eigen::ThreadPoolDevice; using GPUDevice = Eigen::GpuDevice; diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD index 334e70318dd88185cecd93ebeb2587861b7999b9..62996d1fd83f46145e9a1b773b1be57e27903127 100644 --- a/tensorflow/contrib/nccl/BUILD +++ b/tensorflow/contrib/nccl/BUILD @@ -19,17 +19,18 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load("//tensorflow:tensorflow.bzl", "if_not_windows_cuda") tf_custom_op_library( name = "python/ops/_nccl_ops.so", srcs = [ "ops/nccl_ops.cc", ], - gpu_srcs = [ + gpu_srcs = if_not_windows_cuda([ "kernels/nccl_manager.cc", "kernels/nccl_manager.h", "kernels/nccl_ops.cc", - ], + ]), deps = if_cuda([ "@local_config_nccl//:nccl", "//tensorflow/core:gpu_headers_lib", @@ -97,18 +98,19 @@ tf_gen_op_wrapper_py( deps = [":nccl_ops_op_lib"], ) +# Test only nccl ops lib without dso to test behavior when NCCL lib is not +# installed. See nccl_dependency_test for more details. +# +# Users should use the public nccl_py lib that also adds the dso. tf_custom_op_py_library( - name = "nccl_py", + name = "nccl_ops_lib_without_dso", srcs = [ "__init__.py", "python/ops/nccl_ops.py", ], - dso = [":python/ops/_nccl_ops.so"], kernels = if_cuda([":nccl_kernels"]) + [ ":nccl_ops_op_lib", ], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], deps = [ ":nccl_ops", "//tensorflow/contrib/util:util_py", @@ -120,6 +122,15 @@ tf_custom_op_py_library( ], ) +tf_custom_op_py_library( + name = "nccl_py", + dso = [":python/ops/_nccl_ops.so"], + visibility = ["//visibility:public"], + deps = [ + ":nccl_ops_lib_without_dso", + ], +) + cuda_py_test( name = "nccl_ops_test", size = "small", @@ -141,3 +152,25 @@ cuda_py_test( "notap", ], ) + +cuda_py_test( + name = "nccl_dependency_test", + size = "small", + srcs = ["python/ops/nccl_dependency_test.py"], + additional_deps = [ + ":nccl_ops_lib_without_dso", + "//tensorflow/python:constant_op", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:util", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform_test", + ], + # Disable this test internally as static linking is used internally and only + # run for OSS to verify that NCCL is an optional dynamic dependency. + tags = [ + "manual", + "noguitar", + "notap", + ], +) diff --git a/tensorflow/contrib/nccl/python/ops/nccl_dependency_test.py b/tensorflow/contrib/nccl/python/ops/nccl_dependency_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c766080dbee7c9a6f4383ef6fa8cade7bba158af --- /dev/null +++ b/tensorflow/contrib/nccl/python/ops/nccl_dependency_test.py @@ -0,0 +1,59 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Dependency test for nccl to test behavior when NCCL is not installed.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import nccl +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.util import tf_inspect + + +class NcclDependencyTest(test.TestCase): + """Verifies that importing nccl ops lib does not fail even if NCCL is not + installed but nccl ops throws an exception on use if NCCL is not installed. + """ + + def test_nccl_ops(self): + """Tests behavior of nccl ops when NCCL is not installed.""" + + public_methods = [ + m[0] + for m in tf_inspect.getmembers(nccl, tf_inspect.isfunction) + if not m[0].startswith('_') + ] + for method_name in public_methods: + with ops.device('/device:CPU:0'): + tensor = constant_op.constant(1) + + if method_name == 'broadcast': + arg = tensor + else: + arg = [tensor] + + nccl_op = getattr(nccl, method_name) + with ops.device('/device:CPU:0'): + with self.assertRaisesRegexp(errors_impl.NotFoundError, + r'cannot open shared object file'): + nccl_op(arg) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops.py b/tensorflow/contrib/nccl/python/ops/nccl_ops.py index 794372a1f4b0dcc41bcf0da611f5bc2ec9301973..029b01412d96ca03d4ecf7bf4d7d9872864e3ddc 100644 --- a/tensorflow/contrib/nccl/python/ops/nccl_ops.py +++ b/tensorflow/contrib/nccl/python/ops/nccl_ops.py @@ -26,8 +26,10 @@ from tensorflow.python.framework import device from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader -_nccl_ops_so = loader.load_op_library( - resource_loader.get_path_to_datafile('_nccl_ops.so')) + +_nccl_ops_so = None +_module_lock = threading.Lock() +_shared_name_counter = 0 def all_sum(tensors): @@ -180,7 +182,7 @@ def broadcast(tensor): A tensor with the value of `src_tensor`, which can be used as input to ops on other GPU devices. """ - _check_graph_mode() + _validate_and_load_nccl_so() _check_device(tensor) with ops.device(tensor.device): @@ -212,7 +214,7 @@ def _apply_all_reduce(reduction, tensors): """Helper function for all_* functions.""" if not tensors: raise ValueError('Must pass >0 tensors to all reduce operations') - _check_graph_mode() + _validate_and_load_nccl_so() shared_name = _get_shared_name() res = [] @@ -234,7 +236,7 @@ def _apply_reduce(reduction, tensors): """Helper function for reduce_* functions.""" if not tensors: raise ValueError('Must pass >0 tensors to reduce operations') - _check_graph_mode() + _validate_and_load_nccl_so() for t in tensors: _check_device(t) @@ -246,14 +248,10 @@ def _apply_reduce(reduction, tensors): return result -_lock = threading.Lock() -_shared_name_counter = 0 - - def _get_shared_name(): global _shared_name_counter - with _lock: + with _module_lock: val = _shared_name_counter _shared_name_counter += 1 return 'c%s' % val @@ -266,6 +264,25 @@ def _check_device(tensor, expected=None): raise ValueError('Expected device %s, got %s' % (expected, tensor.device)) -def _check_graph_mode(): +def _maybe_load_nccl_ops_so(): + """Loads nccl ops so if it hasn't been loaded already.""" + + with _module_lock: + global _nccl_ops_so + if not _nccl_ops_so: + _nccl_ops_so = loader.load_op_library( + resource_loader.get_path_to_datafile('_nccl_ops.so')) + + +def _validate_and_load_nccl_so(): + """Validates calling context and loads nccl ops so file. + + Raises: + ValueError: Ops are not supported. + errors_impl.NotFoundError: nccl library is not installed. + """ + if context.executing_eagerly(): raise ValueError('Nccl ops are not supported in eager mode') + + _maybe_load_nccl_ops_so() diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index 114b344d38413208755a47f36f45badc1a5ecaa9..bbdf962d0480e52045d31f65b3d137ed3f11f2f1 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -19,6 +19,7 @@ py_library( "python/training/drop_stale_gradient_optimizer.py", "python/training/elastic_average_optimizer.py", "python/training/external_optimizer.py", + "python/training/ggt.py", "python/training/lazy_adam_optimizer.py", "python/training/model_average_optimizer.py", "python/training/moving_average_optimizer.py", @@ -32,12 +33,15 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/optimizer_v2:optimizer_v2_py", "//tensorflow/python:array_ops", "//tensorflow/python:clip_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_ops", "//tensorflow/python:gradients", "//tensorflow/python:init_ops", + "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform", "//tensorflow/python:state_ops", @@ -322,3 +326,21 @@ py_test( "//third_party/py/numpy", ], ) + +py_test( + name = "ggt_test", + srcs = ["python/training/ggt_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py index 5df5d35f8e4f8fcc2c5aa09bd8f3254e16e3a74f..3e63e99030c46c254625ca8fdccce614cd60e8b0 100644 --- a/tensorflow/contrib/opt/__init__.py +++ b/tensorflow/contrib/opt/__init__.py @@ -22,16 +22,18 @@ from __future__ import print_function from tensorflow.contrib.opt.python.training.adamax import * from tensorflow.contrib.opt.python.training.addsign import * from tensorflow.contrib.opt.python.training.drop_stale_gradient_optimizer import * +from tensorflow.contrib.opt.python.training.elastic_average_optimizer import * from tensorflow.contrib.opt.python.training.external_optimizer import * +from tensorflow.contrib.opt.python.training.ggt import * from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import * +from tensorflow.contrib.opt.python.training.model_average_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import * from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import * from tensorflow.contrib.opt.python.training.nadam_optimizer import * from tensorflow.contrib.opt.python.training.weight_decay_optimizers import * from tensorflow.contrib.opt.python.training.powersign import * from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import * -from tensorflow.contrib.opt.python.training.elastic_average_optimizer import * -from tensorflow.contrib.opt.python.training.model_average_optimizer import * +from tensorflow.contrib.opt.python.training.weight_decay_optimizers import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented @@ -58,7 +60,8 @@ _allowed_symbols = [ 'ElasticAverageOptimizer', 'ElasticAverageCustomGetter', 'ModelAverageOptimizer', - 'ModelAverageCustomGetter' + 'ModelAverageCustomGetter', + 'GGTOptimizer', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/opt/python/training/ggt.py b/tensorflow/contrib/opt/python/training/ggt.py new file mode 100644 index 0000000000000000000000000000000000000000..928c453517f825ed2d305ec498d07ac29c065f1a --- /dev/null +++ b/tensorflow/contrib/opt/python/training/ggt.py @@ -0,0 +1,312 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""GGT for Tensorflow.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import numpy as np +from tensorflow.contrib.optimizer_v2 import optimizer_v2 +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops + + +class GGTOptimizer(optimizer_v2.OptimizerV2): + """Optimizer that implements the GGT algorithm. + + GGT has an advantage over sgd and adam on large models with poor conditioning, + for example language models and CNNs, + see [ABCHSZZ 2018]([pdf](https://arxiv.org/pdf/1806.02958.pdf)). + """ + + def __init__(self, + learning_rate=0.001, + beta1=0.9, + use_locking=False, + name="GGT", + window=10, + eps=1e-4, + svd_eps=1e-6, + sigma_eps=1e-2): + """Construct a new GGT optimizer. + + Initialization: + + ``` + t <- 0 (Initialize timestep) + grad_buffer <- 0 (Initialize buffer for keeping past gradients) + flat_grad <- 0 (Initialize flattened gradient that contains gradients of all + variables) + m_0 <- 0 (Initialize 1st moment vector) + ``` + + Suppose all variables and their gradients are concatenated into vectors + `flat_vars` and `flat_grad`. The update rule for `flat_vars` + uses an optimization described at the beginning of section 2 of the paper: + + ``` + t <- t + 1 + + m_t <- beta1 * m_{t-1} + (1 - beta1) * flat_grad + grad_buffer[(t-1) % window, :] <- m_t + + M <- grad_buffer^T / sqrt(min(t, window)) + U, sigma, _ <- SVD(M^TM + I * svd_eps) + + sigma_sqrt_inv <- (sqrt(sigma) + sigma_eps)^(-3) + sigma_sqrt_min <- min(sqrt(sigma)) + + if sigma_sqrt_min > eps: + new_step <- M U diag(sigma_sqrt_inv) U^T M^T m_t + + (m_t - M U diag(1/sigma) U^T M^T m_t) / sigma_sqrt_min + else: + new_step <- M U diag(sigma_sqrt_inv) U^T M^T m_t + + flat_vars <- flat_vars - learning_rate * new_step + ``` + + GGT provides the power of full-matrix adaptive regularization at a cost not + much larger than SGD. As a result it is suited for large models where the + gradient covariance matrix has a poor condition number that slows down first + order methods. + GGT uses the preconditioner from full-matrix AdaGrad, with gradient history + attenuated exponentially as in Adam, and truncated to a window parameter. + It has provable guarantees even for non-convex optimization that is never + significantly worse than SGD and in some cases better. + + Args: + learning_rate: A float hyperparameter. The learning rate. + beta1: A float hyperparameter. The exponential decay rate for the 1st + moment estimates. + use_locking: If True use locks for update operations. + name: Optional name for the operations created when applying gradients. + Defaults to "GGT". + window: An integer hyperparameter. The number of first moments to keep in + computing the adaptive preconditioner. + eps: A float hyperparameter. Used to truncate small eigenvalues of the + gradient covariance matrix. + svd_eps: A float hyperparameter. Used to stabilize SVD. + sigma_eps: A float hyperparameter. Used to regularize matrix inversion. + """ + super(GGTOptimizer, self).__init__(use_locking, name) + self._set_hyper("lr", learning_rate) + self._set_hyper("beta1", beta1) + self._set_hyper("window", window) + self._set_hyper("eps", eps) + self._set_hyper("svd_eps", svd_eps) + self._set_hyper("sigma_eps", sigma_eps) + + self.index_dict = {} + self.shape_dict = {} + + def _create_vars(self, var_list, state): + # Construct ordered dictionary for variable dimensions, sorted by name. + shape_dict = {} + for v in var_list: + shape_dict[v.name] = np.prod(v.get_shape()).value + self.shape_dict = collections.OrderedDict( + sorted(shape_dict.items(), key=lambda t: t[0])) + + # Assign each variable its location in flat_grad. The locations are based on + # the order of sorted names. + idx = 0 + for v_name, v_dim in self.shape_dict.items(): + self.index_dict[v_name] = idx + idx += v_dim + + state.create_non_slot( + initial_value=math_ops.cast(0., dtype=var_list[0].dtype.base_dtype), + name="global_step") + + # Buffer for keeping past gradients. + window = state.get_hyper("window") + grad_buffer_init = array_ops.zeros( + [window, idx], dtype=var_list[0].dtype.base_dtype) + state.create_non_slot(initial_value=grad_buffer_init, name="grad_buffer") + + state.create_non_slot( + initial_value=array_ops.zeros( + (idx,), dtype=var_list[0].dtype.base_dtype), + name="moment1") + + # Flattened gradient that contains gradients for all variables in the model. + state.create_non_slot( + initial_value=array_ops.zeros( + (idx,), dtype=var_list[0].dtype.base_dtype), + name="flat_grad") + + def _get_global_step(self, state=None): + if state is None: + state = self._get_per_graph_state() + return state.get_non_slot("global_step") + + def _get_moment1(self, state=None): + if state is None: + state = self._get_per_graph_state() + return state.get_non_slot("moment1") + + def _get_grad_buffer(self, state=None): + if state is None: + state = self._get_per_graph_state() + return state.get_non_slot("grad_buffer") + + def _get_flat_grad(self, state=None): + if state is None: + state = self._get_per_graph_state() + return state.get_non_slot("flat_grad") + + def _apply_sparse(self, grad, var): + raise NotImplementedError("Sparse gradient updates are not supported.") + + def _prepare(self, state): + self._variables = [] + + def _apply_dense(self, grad, var, state): + self._variables.append(var) + dim = self.shape_dict[var.name] + start_index = self.index_dict[var.name] + end_index = start_index + dim + + # Update flat_gradient at the index associated with the variable. + flat_grad = self._get_flat_grad(state) + new_flat_grad = array_ops.reshape(grad, [-1]) + flat_grad_updated = state_ops.scatter_update( + flat_grad, math_ops.range(start_index, end_index), new_flat_grad) + + return flat_grad_updated + + def _resource_apply_dense(self, grad, var, state): + self._variables.append(var) + dim = self.shape_dict[var.name] + start_index = self.index_dict[var.name] + end_index = start_index + dim + + # Update flat_gradient at the index associated with the variable. + flat_grad = self._get_flat_grad(state) + new_flat_grad = array_ops.reshape(grad, [-1]) + flat_grad_updated = state_ops.scatter_update( + flat_grad, math_ops.range(start_index, end_index), new_flat_grad) + + return flat_grad_updated + + def _finish(self, state): + var_dtype = self._variables[0].dtype.base_dtype + # Update global step. + global_step = self._get_global_step(state) + update_global_step = state_ops.assign_add(global_step, 1.) + + # Update the first moment estimate. + beta1 = state.get_hyper("beta1", dtype=var_dtype) + moment1 = self._get_moment1(state) + flat_grad = self._get_flat_grad(state) + # moment1_t := beta1 * moment1_{t-1} + (1 - beta1) * flat_grad_t + update_moment1 = moment1.assign(beta1 * moment1 + (1. - beta1) * flat_grad) + + # Update the gradient buffer. + window = state.get_hyper("window") + grad_buffer = self._get_grad_buffer(state) + next_grad_index = math_ops.floormod( + math_ops.to_int32(update_global_step - 1.), window) + # grad_buffer[(t-1) % window] := moment1_t + update_grad_buffer = state_ops.scatter_update(grad_buffer, next_grad_index, + update_moment1) + + # Compute the update step. + eps = state.get_hyper("eps", dtype=var_dtype) + svd_eps = state.get_hyper("svd_eps", dtype=var_dtype) + sigma_eps = state.get_hyper("sigma_eps", dtype=var_dtype) + lr = state.get_hyper("lr", dtype=var_dtype) + denom = math_ops.sqrt( + math_ops.minimum( + ops.convert_to_tensor(update_global_step), + ops.convert_to_tensor(math_ops.cast(window, dtype=var_dtype)))) + moment1_2d = array_ops.expand_dims(update_moment1, -1) + + # m = grad_buffer^T / sqrt(min(t, window)) + # m has shape [model dimension, window], where model dimension is the sum + # of the dimensions of the flattened variables. + m = array_ops.transpose(math_ops.divide(update_grad_buffer, denom)) + + # sigma, u, _ = SVD(m^Tm + I * svd_eps) + mm = math_ops.matmul(m, m, transpose_a=True) + damping = math_ops.cast(linalg_ops.eye(window), dtype=var_dtype) * svd_eps + sigma, u, _ = linalg_ops.svd(mm + damping) + sigma_sqrt = math_ops.sqrt(sigma) + sigma_sqrt_min = math_ops.reduce_min(sigma_sqrt) + + # sigma_sqrt_inv = 1 / (\sqrt{sigma} + sigma_eps) ^ 3 + # We add sigma_eps to alleviate numerical instability. + # Note that (m^Tm)^(-3/2) = u diag(sigma_sqrt_inv) u^T. + sigma_sqrt_inv = math_ops.divide( + math_ops.cast(1.0, dtype=var_dtype), + math_ops.pow(sigma_sqrt + sigma_eps, 3)) + + # In full matrix AdaGrad, the update step computes (mm^T)^(-1/2)g, where the + # inversion of a model dimension by model dimension matrix is needed. To + # speed up this computation we calculate the following instead: + # m(m^Tm)^(-3/2)m^T moment1 = m u diag(sigma_sqrt_inv) u^T m^T moment1. + new_step = array_ops.expand_dims( + array_ops.zeros(flat_grad.get_shape(), dtype=var_dtype), -1) + head = math_ops.matmul( + m, + math_ops.matmul( + u, + math_ops.matmul( + array_ops.diag(sigma_sqrt_inv), + math_ops.matmul( + u, + math_ops.matmul(m, moment1_2d, transpose_a=True), + transpose_a=True)))) + + # When inverting (mm^t)^(1/2), we also add epsilon * I regularization for + # degenerate cases. We expand ((mm^t)^(1/2) + epsilon * I)^(-1) using + # Woodbury's identity. + # For full derivation please see paper at + # https://arxiv.org/pdf/1806.02958.pdf + tail = moment1_2d - math_ops.matmul( + m, + math_ops.matmul( + u, + math_ops.matmul( + array_ops.diag( + math_ops.divide(math_ops.cast(1.0, dtype=var_dtype), + sigma)), + math_ops.matmul( + u, + math_ops.matmul(m, moment1_2d, transpose_a=True), + transpose_a=True)))) + scaled_tail = math_ops.divide(tail, sigma_sqrt_min) + + update_new_step = control_flow_ops.cond( + sigma_sqrt_min > eps, lambda: math_ops.add(head, scaled_tail), + lambda: math_ops.add(new_step, head)) + + # Update each variable. + update_step = [] + for var in self._variables: + dim = self.shape_dict[var.name] + start_index = self.index_dict[var.name] + end_index = start_index + dim + var_update_correct_shape = array_ops.reshape( + update_new_step[start_index:end_index], var.get_shape()) + var_updated = state_ops.assign_sub(var, lr * var_update_correct_shape) + update_step.append(var_updated) + + return control_flow_ops.group(update_step) diff --git a/tensorflow/contrib/opt/python/training/ggt_test.py b/tensorflow/contrib/opt/python/training/ggt_test.py new file mode 100644 index 0000000000000000000000000000000000000000..42162960b049cd90c663989fb4fc9d7f179a84ff --- /dev/null +++ b/tensorflow/contrib/opt/python/training/ggt_test.py @@ -0,0 +1,183 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for GGTOptimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.contrib.opt.python.training.ggt import GGTOptimizer +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def ggt_update_numpy(param, + g_t, + lr, + grad_buffer, + m, + window, + t, + beta1=0.9, + eps=1e-4, + svd_eps=1e-6, + sigma_eps=1e-2): + """Tests the correctness of one step of GGT.""" + m_t = m * beta1 + (1 - beta1) * g_t + grad_buffer[((t - 1) % window), :] = m_t + m_matrix = np.transpose(grad_buffer / np.sqrt(np.minimum(t, window))) + mm = np.dot(np.transpose(m_matrix), m_matrix) + damping = np.eye(window) * svd_eps + u, sigma, _ = np.linalg.svd(mm + damping) + + sigma_sqrt_inv = np.power(np.sqrt(sigma) + sigma_eps, -3) + new_step = np.linalg.multi_dot([ + m_matrix, u, + np.diag(sigma_sqrt_inv), + np.transpose(u), + np.transpose(m_matrix), m_t + ]) + + sigma_sqrt_min = np.sqrt(sigma).min() + + if sigma_sqrt_min > eps: + new_step += (m_t - np.linalg.multi_dot([ + m_matrix, u, + np.diag(1.0 / sigma), + np.transpose(u), + np.transpose(m_matrix), m_t + ])) * (1.0 / sigma_sqrt_min) + + param_t = param - lr * new_step + return param_t, m_t, grad_buffer + + +class GGTOptimizerTest(test.TestCase): + + def doTestBasic(self, use_resource=False): + # SVD does not support float16 + for i, dtype in enumerate([dtypes.float32, dtypes.float64]): + with self.test_session(graph=ops.Graph()): + # Initialize variables for numpy implementation. + m0 = 0.0 + window = 3 + grad_buffer = np.zeros((window, 4), dtype=dtype.as_numpy_dtype) + lr = 0.001 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + var0 = variables.Variable(var0_np, name="var0") + var1 = variables.Variable(var1_np, name="var1") + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + opt = GGTOptimizer(learning_rate=lr, window=window) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + opt_variables = opt.variables() + + m_t = opt._get_moment1() + grad_buffer_t = opt._get_grad_buffer() + g_t = opt._get_flat_grad() + self.assertTrue(m_t is not None) + self.assertTrue(grad_buffer_t is not None) + self.assertTrue(g_t is not None) + self.assertIn(m_t, opt_variables) + self.assertIn(grad_buffer_t, opt_variables) + self.assertIn(g_t, opt_variables) + + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + m_t = opt._get_moment1() + grad_buffer_t = opt._get_grad_buffer() + g_t = opt._get_flat_grad() + + # Run 3 steps of GGT + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + elif t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + if t == 1: + self.assertAllCloseAccordingToType( + np.array([0.01, 0.01, 0.001, 0.001]), self.evaluate(m_t)) + self.assertAllCloseAccordingToType( + np.array([[0.01, 0.01, 0.001, 0.001], [0., 0., 0., 0.], + [0., 0., 0., 0.]]), self.evaluate(grad_buffer_t)) + elif t == 2: + self.assertAllCloseAccordingToType( + np.array([0.019, 0.019, 0.0019, 0.0019]), self.evaluate(m_t)) + self.assertAllCloseAccordingToType( + np.array([[0.01, 0.01, 0.001, 0.001], + [0.019, 0.019, 0.0019, 0.0019], [0., 0., 0., 0.]]), + self.evaluate(grad_buffer_t)) + else: + self.assertAllCloseAccordingToType( + np.array([0.0271, 0.0271, 0.00271, 0.00271]), + self.evaluate(m_t)) + self.assertAllCloseAccordingToType( + np.array([[0.01, 0.01, 0.001, + 0.001], [0.019, 0.019, 0.0019, 0.0019], + [0.0271, 0.0271, 0.00271, 0.00271]]), + self.evaluate(grad_buffer_t)) + + self.assertAllCloseAccordingToType([0.1, 0.1, 0.01, 0.01], + self.evaluate(g_t)) + + var_np = np.append(var0_np, var1_np) + grads_np = np.append(grads0_np, grads1_np) + var_np, m0, grad_buffer = ggt_update_numpy(var_np, grads_np, lr, + grad_buffer, m0, window, t) + + var0_np = var_np[:2] + var1_np = var_np[2:] + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testBasic(self): + with self.test_session(): + self.doTestBasic(use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTestBasic(use_resource=True) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py index 8aa40aeb45d4ec15140bdfc5ebd824e8aa08d8d9..b9cf40eb7b2d11c98b93c51213145ca4e2670318 100644 --- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py @@ -19,13 +19,13 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops -from tensorflow.python.training import optimizer from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops from tensorflow.python.training import adam from tensorflow.python.training import momentum as momentum_opt +from tensorflow.python.training import optimizer from tensorflow.python.util.tf_export import tf_export -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import resource_variable_ops class DecoupledWeightDecayExtension(object): @@ -65,7 +65,7 @@ class DecoupledWeightDecayExtension(object): Args: weight_decay: A `Tensor` or a floating point value, the factor by which a variable is decayed in the update step. - decay_var_list: Optional list or tuple or set of `Variable` objects to + **kwargs: Optional list or tuple or set of `Variable` objects to decay. """ self._decay_var_list = None # is set in minimize or apply_gradients @@ -85,6 +85,28 @@ class DecoupledWeightDecayExtension(object): If decay_var_list is None, all variables in var_list are decayed. For more information see the documentation of Optimizer.minimize. + + Args: + loss: A `Tensor` containing the value to minimize. + global_step: Optional `Variable` to increment by one after the + variables have been updated. + var_list: Optional list or tuple of `Variable` objects to update to + minimize `loss`. Defaults to the list of variables collected in + the graph under the key `GraphKeys.TRAINABLE_VARIABLES`. + gate_gradients: How to gate the computation of gradients. Can be + `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. + aggregation_method: Specifies the method used to combine gradient terms. + Valid values are defined in the class `AggregationMethod`. + colocate_gradients_with_ops: If True, try colocating gradients with + the corresponding op. + name: Optional name for the returned operation. + grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`. + decay_var_list: Optional list of decay variables. + + Returns: + An Operation that updates the variables in `var_list`. If `global_step` + was not `None`, that operation also increments `global_step`. + """ self._decay_var_list = set(decay_var_list) if decay_var_list else False return super(DecoupledWeightDecayExtension, self).minimize( @@ -103,6 +125,19 @@ class DecoupledWeightDecayExtension(object): are decayed. For more information see the documentation of Optimizer.apply_gradients. + + Args: + grads_and_vars: List of (gradient, variable) pairs as returned by + `compute_gradients()`. + global_step: Optional `Variable` to increment by one after the + variables have been updated. + name: Optional name for the returned operation. Default to the + name passed to the `Optimizer` constructor. + decay_var_list: Optional list of decay variables. + + Returns: + An `Operation` that applies the specified gradients. If `global_step` + was not None, that operation also increments `global_step`. """ self._decay_var_list = set(decay_var_list) if decay_var_list else False return super(DecoupledWeightDecayExtension, self).apply_gradients( @@ -197,6 +232,7 @@ def extend_with_decoupled_weight_decay(base_optimizer): A new optimizer class that inherits from DecoupledWeightDecayExtension and base_optimizer. """ + class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension, base_optimizer): """Base_optimizer with decoupled weight decay. diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py index 74d1cdbbdac8724518937d141a976abf9fec6ce3..76d8a5697acb79e7748175c4a81dfdd85807dd49 100644 --- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.opt.python.training import weight_decay_optimizers from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -29,7 +30,6 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import adam -from tensorflow.contrib.opt.python.training import weight_decay_optimizers WEIGHT_DECAY = 0.01 @@ -91,7 +91,6 @@ class WeightDecayOptimizerTest(test.TestCase): opt = optimizer() update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - if not context.executing_eagerly(): with ops.Graph().as_default(): # Shouldn't return non-slot variables from other graphs. @@ -171,9 +170,9 @@ class ExtendWithWeightDecayTest(WeightDecayOptimizerTest): @staticmethod def get_optimizer(): - AdamW = weight_decay_optimizers.extend_with_decoupled_weight_decay( + adamw = weight_decay_optimizers.extend_with_decoupled_weight_decay( adam.AdamOptimizer) - return AdamW(WEIGHT_DECAY) + return adamw(WEIGHT_DECAY) def testBasic(self): self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m", @@ -185,6 +184,5 @@ class ExtendWithWeightDecayTest(WeightDecayOptimizerTest): use_resource=True) - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/optimizer_v2/adam.py b/tensorflow/contrib/optimizer_v2/adam.py index d538ad0fb02699ed8514f512208914f629a47436..631d4f44dfb646541244bfe1d15136dd29f02703 100644 --- a/tensorflow/contrib/optimizer_v2/adam.py +++ b/tensorflow/contrib/optimizer_v2/adam.py @@ -103,9 +103,9 @@ class AdamOptimizer(optimizer_v2.OptimizerV2): def _create_vars(self, var_list, state): # Non-slot variables end up on the same device(s). - state.create_non_slot(initial_value=state.get_hyper("beta1"), + state.create_non_slot(initial_value=lambda: state.get_hyper("beta1"), name="beta1_power") - state.create_non_slot(initial_value=state.get_hyper("beta2"), + state.create_non_slot(initial_value=lambda: state.get_hyper("beta2"), name="beta2_power") # Create slots for the first and second moments. diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py index 64b95786b5c7a71ee514201d8eb60c26975938b5..06ab58188a2fffa0e3a810d451875ca951a077b9 100644 --- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py +++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py @@ -43,15 +43,15 @@ from tensorflow.python.ops import template from tensorflow.python.ops import variable_scope from tensorflow.python.training import saver as core_saver from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.checkpointable import util -class NonLayerCheckpointable(checkpointable.Checkpointable): +class NonLayerCheckpointable(tracking.Checkpointable): def __init__(self): super(NonLayerCheckpointable, self).__init__() - self.a_variable = checkpointable_utils.add_variable( + self.a_variable = util.add_variable( self, name="a_variable", shape=[]) @@ -88,29 +88,6 @@ class _MirroringSaveable( self._mirrored_variable.assign(tensor)) -class _OwnsMirroredVariables(checkpointable.CheckpointableBase): - """A Checkpointable object which returns a more complex SaveableObject.""" - - def __init__(self): - self.non_dep_variable = variable_scope.get_variable( - name="non_dep_variable", initializer=6., use_resource=True) - self.mirrored = variable_scope.get_variable( - name="mirrored", initializer=15., use_resource=True) - - def _gather_saveables_for_checkpoint(self): - def _saveable_factory(name=self.non_dep_variable.name): - return _MirroringSaveable( - primary_variable=self.non_dep_variable, - mirrored_variable=self.mirrored, - name=name) - return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} - - # The Saver sorts by name before parsing, so we need a name property. - @property - def name(self): - return self.non_dep_variable.name - - class CheckpointingTests(test.TestCase): @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) @@ -122,7 +99,7 @@ class CheckpointingTests(test.TestCase): other_model = MyModel() optimizer = adam.AdamOptimizer(0.001) optimizer_step = training_util.get_or_create_global_step() - root_checkpointable = checkpointable_utils.Checkpoint( + root_checkpointable = util.Checkpoint( optimizer=optimizer, model=model, optimizer_step=optimizer_step) if context.executing_eagerly(): optimizer.minimize( @@ -137,11 +114,11 @@ class CheckpointingTests(test.TestCase): optimizer.minimize( other_model(input_value), global_step=optimizer_step) - self.evaluate(checkpointable_utils.gather_initializers( + self.evaluate(util.gather_initializers( root_checkpointable)) self.evaluate(train_op) named_variables, serialized_graph, _ = ( - checkpointable_utils._serialize_object_graph( + util._serialize_object_graph( root_checkpointable, saveables_cache=None)) expected_checkpoint_names = ( # Created in the root node, so no prefix. @@ -226,11 +203,11 @@ class CheckpointingTests(test.TestCase): optimizer_node.slot_variables[0] .slot_variable_node_id].attributes[0].checkpoint_key) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSaveRestore(self): model = MyModel() optimizer = adam.AdamOptimizer(0.001) - root_checkpointable = checkpointable_utils.Checkpoint( + root_checkpointable = util.Checkpoint( optimizer=optimizer, model=model) input_value = constant_op.constant([[3.]]) if context.executing_eagerly(): @@ -240,7 +217,7 @@ class CheckpointingTests(test.TestCase): train_op = optimizer.minimize(model(input_value)) # TODO(allenl): Make initialization more pleasant when graph building. root_checkpointable.save_counter # pylint: disable=pointless-statement - self.evaluate(checkpointable_utils.gather_initializers( + self.evaluate(util.gather_initializers( root_checkpointable)) self.evaluate(train_op) prefix = os.path.join(self.get_temp_dir(), "ckpt") @@ -266,7 +243,7 @@ class CheckpointingTests(test.TestCase): # Preserve beta1_power and beta2_power when appying gradients so we can # test that they've been restored correctly. beta1=1.0, beta2=1.0) - on_create_root = checkpointable_utils.Checkpoint( + on_create_root = util.Checkpoint( optimizer=on_create_optimizer, model=on_create_model) # Deferred restoration status = on_create_root.restore(save_path=save_path) @@ -298,7 +275,7 @@ class CheckpointingTests(test.TestCase): for training_continuation in range(3): model = MyModel() optimizer = adam.AdamOptimizer(0.001) - root = checkpointable_utils.Checkpoint( + root = util.Checkpoint( optimizer=optimizer, model=model, optimizer_step=training_util.get_or_create_global_step()) root.restore(core_saver.latest_checkpoint(checkpoint_directory)) @@ -322,7 +299,7 @@ class CheckpointingTests(test.TestCase): with ops.Graph().as_default(): model = MyModel() optimizer = adam.AdamOptimizer(0.001) - root = checkpointable_utils.Checkpoint( + root = util.Checkpoint( optimizer=optimizer, model=model, global_step=training_util.get_or_create_global_step()) input_value = constant_op.constant([[3.]]) @@ -347,7 +324,7 @@ class CheckpointingTests(test.TestCase): self.assertEqual(training_continuation + 1, session.run(root.save_counter)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAgnosticUsage(self): """Graph/eager agnostic usage.""" # Does create garbage when executing eagerly due to ops.Graph() creation. @@ -359,7 +336,7 @@ class CheckpointingTests(test.TestCase): graph=ops.get_default_graph()), test_util.device(use_gpu=True): model = MyModel() optimizer = adam.AdamOptimizer(0.001) - root = checkpointable_utils.Checkpoint( + root = util.Checkpoint( optimizer=optimizer, model=model, global_step=training_util.get_or_create_global_step()) checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) @@ -381,7 +358,7 @@ class CheckpointingTests(test.TestCase): self.evaluate(root.save_counter)) # pylint: disable=cell-var-from-loop - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testWithDefun(self): num_training_steps = 2 checkpoint_directory = self.get_temp_dir() @@ -392,7 +369,7 @@ class CheckpointingTests(test.TestCase): model = MyModel() # Don't actually train so we can test variable values optimizer = adam.AdamOptimizer(0.) - root = checkpointable_utils.Checkpoint( + root = util.Checkpoint( optimizer=optimizer, model=model, global_step=training_util.get_or_create_global_step()) checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) @@ -442,7 +419,7 @@ class CheckpointingTests(test.TestCase): optimizer = adam.AdamOptimizer(learning_rate=0.05) checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - checkpoint = checkpointable_utils.Checkpoint( + checkpoint = util.Checkpoint( model=model, optimizer=optimizer) for _ in range(2): checkpoint.save(checkpoint_prefix) @@ -453,12 +430,12 @@ class CheckpointingTests(test.TestCase): optimizer.apply_gradients( [(g, v) for g, v in zip(grad, model.vars)]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDeferredSlotRestoration(self): checkpoint_directory = self.get_temp_dir() - root = checkpointable.Checkpointable() - root.var = checkpointable_utils.add_variable( + root = tracking.Checkpointable() + root.var = util.add_variable( root, name="var", initializer=0.) optimizer = adam.AdamOptimizer(0.1) if context.executing_eagerly(): @@ -468,28 +445,28 @@ class CheckpointingTests(test.TestCase): # Note that `optimizer` has not been added as a dependency of # `root`. Create a one-off grouping so that slot variables for `root.var` # get initialized too. - self.evaluate(checkpointable_utils.gather_initializers( - checkpointable_utils.Checkpoint(root=root, optimizer=optimizer))) + self.evaluate(util.gather_initializers( + util.Checkpoint(root=root, optimizer=optimizer))) self.evaluate(train_op) self.evaluate(state_ops.assign(root.var, 12.)) - no_slots_path = checkpointable_utils.CheckpointableSaver(root).save( + no_slots_path = util.CheckpointableSaver(root).save( os.path.join(checkpoint_directory, "no_slots")) root.optimizer = optimizer self.evaluate(state_ops.assign(root.var, 13.)) self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var), 14.)) - slots_path = checkpointable_utils.CheckpointableSaver(root).save( + slots_path = util.CheckpointableSaver(root).save( os.path.join(checkpoint_directory, "with_slots")) - new_root = checkpointable.Checkpointable() + new_root = tracking.Checkpointable() # Load the slot-containing checkpoint (deferred), then immediately overwrite # the non-slot variable (also deferred). - slot_status = checkpointable_utils.CheckpointableSaver( + slot_status = util.CheckpointableSaver( new_root).restore(slots_path) - no_slot_status = checkpointable_utils.CheckpointableSaver( + no_slot_status = util.CheckpointableSaver( new_root).restore(no_slots_path) with self.assertRaises(AssertionError): no_slot_status.assert_consumed() - new_root.var = checkpointable_utils.add_variable( + new_root.var = util.add_variable( new_root, name="var", shape=[]) no_slot_status.assert_consumed() no_slot_status.run_restore_ops() @@ -525,12 +502,12 @@ class CheckpointingTests(test.TestCase): with graph.as_default(), self.test_session(graph): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = checkpointable.Checkpointable() + obj = tracking.Checkpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) - self.evaluate(checkpointable_utils.gather_initializers(obj)) - saver = checkpointable_utils.CheckpointableSaver(obj) + self.evaluate(util.gather_initializers(obj)) + saver = util.CheckpointableSaver(obj) saver.save(checkpoint_prefix) before_ops = graph.get_operations() saver.save(checkpoint_prefix) @@ -543,12 +520,12 @@ class CheckpointingTests(test.TestCase): with graph.as_default(), self.test_session(graph): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = checkpointable.Checkpointable() + obj = tracking.Checkpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) - self.evaluate(checkpointable_utils.gather_initializers(obj)) - saver = checkpointable_utils.CheckpointableSaver(obj) + self.evaluate(util.gather_initializers(obj)) + saver = util.CheckpointableSaver(obj) save_path = saver.save(checkpoint_prefix) saver.restore(save_path) before_ops = graph.get_operations() @@ -565,10 +542,10 @@ class CheckpointingTests(test.TestCase): first_session = session_lib.Session(graph=first_graph) with first_graph.as_default(), first_session.as_default(): first_variable = resource_variable_ops.ResourceVariable([1.]) - first_root_checkpointable = checkpointable_utils.Checkpoint( + first_root_checkpointable = util.Checkpoint( optimizer=optimizer, variable=first_variable) train_op = optimizer.minimize(first_variable.read_value) - self.evaluate(checkpointable_utils.gather_initializers( + self.evaluate(util.gather_initializers( first_root_checkpointable)) self.evaluate(train_op) self.evaluate(first_variable.assign([1.])) @@ -581,7 +558,7 @@ class CheckpointingTests(test.TestCase): second_graph = ops.Graph() with second_graph.as_default(), session_lib.Session(graph=second_graph): second_variable = resource_variable_ops.ResourceVariable([1.]) - second_root_checkpointable = checkpointable_utils.Checkpoint( + second_root_checkpointable = util.Checkpoint( optimizer=optimizer, variable=second_variable) train_op = optimizer.minimize(second_variable.read_value) second_root_checkpointable.restore(None).initialize_or_restore() @@ -616,7 +593,7 @@ class CheckpointingTests(test.TestCase): class TemplateTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_checkpointable_save_restore(self): def _templated(): @@ -631,7 +608,7 @@ class TemplateTests(test.TestCase): save_template = template.make_template("s1", _templated) v1_save, _, v2_save = save_template() optimizer = adam.AdamOptimizer(0.0) - save_root = checkpointable_utils.Checkpoint( + save_root = util.Checkpoint( my_template=save_template, optimizer=optimizer) optimizer.minimize(v1_save.read_value) self.evaluate([v.initializer for v in optimizer.variables()]) @@ -643,7 +620,7 @@ class TemplateTests(test.TestCase): load_template = template.make_template("s2", _templated) load_optimizer = adam.AdamOptimizer(0.0) - load_root = checkpointable_utils.Checkpoint( + load_root = util.Checkpoint( my_template=load_template, optimizer=load_optimizer) status = load_root.restore(save_path) var, var_plus_one, var2 = load_template() @@ -664,12 +641,12 @@ class CheckpointCompatibilityTests(test.TestCase): model = MyModel() optimizer = adam.AdamOptimizer(0.001) optimizer_step = training_util.get_or_create_global_step() - root_checkpointable = checkpointable_utils.Checkpoint( + root_checkpointable = util.Checkpoint( optimizer=optimizer, model=model, optimizer_step=optimizer_step) train_op = optimizer.minimize( functools.partial(model, input_value), global_step=optimizer_step) - self.evaluate(checkpointable_utils.gather_initializers( + self.evaluate(util.gather_initializers( root_checkpointable)) self.evaluate(train_op) # A regular variable, a slot variable, and a non-slot Optimizer variable @@ -712,7 +689,7 @@ class CheckpointCompatibilityTests(test.TestCase): sess=session, save_path=checkpoint_prefix, global_step=root.optimizer_step) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLoadFromNameBasedSaver(self): """Save a name-based checkpoint, load it using the object-based API.""" with test_util.device(use_gpu=True): @@ -721,7 +698,7 @@ class CheckpointCompatibilityTests(test.TestCase): self._set_sentinels(root) with self.assertRaises(AssertionError): self._check_sentinels(root) - object_saver = checkpointable_utils.CheckpointableSaver(root) + object_saver = util.CheckpointableSaver(root) self._set_sentinels(root) status = object_saver.restore(save_path) if context.executing_eagerly(): diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index f537318b32986c941b6c41eb363929e906027dd7..c6f3bd6ee18fa353944e2fc303573894933f5b27 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -162,12 +162,12 @@ def _get_processor(v): def _var_key_v2(var): """Key for representing a primary variable, for looking up slots.""" # pylint: disable=protected-access - if hasattr(var, "_mirrored_container"): - mirrored_container = var._mirrored_container() - assert mirrored_container is not None + if hasattr(var, "_distributed_container"): + distributed_container = var._distributed_container() + assert distributed_container is not None if context.executing_eagerly(): - return mirrored_container._unique_id - return mirrored_container._shared_name + return distributed_container._unique_id + return distributed_container._shared_name if context.executing_eagerly(): return var._unique_id return var.op.name @@ -211,8 +211,9 @@ class _OptimizerV2State(object): # This dict starts with a single item with key "None" with the hyper # parameter value converted to a Tensor. Other items have dtype keys # with that Tensor cast to that dtype. - self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)} - for name, (dynamic, value) in hyper.items() if not dynamic} + with ops.init_scope(): + self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)} + for name, (dynamic, value) in hyper.items() if not dynamic} self._slots = {} self._non_slot_dict = {} # Extra state to help Optimizers implement Checkpointable. Holds information diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py index 8599af32f6f4cc5529cd812e83c02ef3812cb71e..ec033c4a0163ba9ed39e55fa9e92dfdadc9a1b2f 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py @@ -35,7 +35,7 @@ from tensorflow.python.platform import test class OptimizerTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBasic(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -113,7 +113,7 @@ class OptimizerTest(test.TestCase): self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)], var1.eval()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoVariables(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: # pylint: disable=cell-var-from-loop @@ -128,7 +128,7 @@ class OptimizerTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'No.*variables'): sgd_op.minimize(loss) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoGradients(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -146,7 +146,7 @@ class OptimizerTest(test.TestCase): # var1 has no gradient sgd_op.minimize(loss, var_list=[var1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoGradientsForAnyVariables_Minimize(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -162,7 +162,7 @@ class OptimizerTest(test.TestCase): 'No gradients provided for any variable'): sgd_op.minimize(loss, var_list=[var0, var1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoGradientsForAnyVariables_ApplyGradients(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -176,7 +176,7 @@ class OptimizerTest(test.TestCase): 'No gradients provided for any variable'): sgd_op.apply_gradients([(None, var0), (None, var1)]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientsAsVariables(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -216,7 +216,7 @@ class OptimizerTest(test.TestCase): self.assertAllClose([-14., -13.], self.evaluate(var0)) self.assertAllClose([-6., -5.], self.evaluate(var1)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testComputeGradientsWithTensors(self): x = ops.convert_to_tensor(1.0) def f(): diff --git a/tensorflow/contrib/periodic_resample/BUILD b/tensorflow/contrib/periodic_resample/BUILD index 976b312e8345a801ad07f622b6117b88af2cf603..f2171efc959362c1e4392fefbd5842f0883571d7 100644 --- a/tensorflow/contrib/periodic_resample/BUILD +++ b/tensorflow/contrib/periodic_resample/BUILD @@ -97,6 +97,8 @@ tf_cc_test( ], deps = [ ":all_ops", + "//tensorflow/core:framework", + "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", ], diff --git a/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc b/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc index 55edf76fcd3eed461e1465b569e1c2e9e2facbc0..43b7c1799ffb2e27f9d15bc6011d49334867b6ec 100644 --- a/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc +++ b/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/shape_inference_testutil.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index 55479bf5f74299bf09f131a6127f9f11d6192d90..e3c48998305e9d9b6c185fd4c0f324fa0449c691 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -121,7 +121,8 @@ def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay): scaled_weight_tensor = math_ops.multiply( weights, multiplier_tensor, name='mul_fold') new_layer_tensor = _CloneWithNewOperands( - match.layer_op, match.input_tensor, scaled_weight_tensor) + match.layer_op, match.input_tensor, scaled_weight_tensor, + match.batch_to_space_op) if correction_recip is not None: new_layer_tensor = math_ops.multiply( @@ -149,6 +150,8 @@ def _FindFusedBatchNorms(graph): _FusedBatchNormMatches. """ input_pattern = graph_matcher.OpTypePattern('*') + # In practice, the weight pattern can match a Variable or a SpaceToBatchND + # operation that follows a variable for atrous convolutions. weight_pattern = graph_matcher.OpTypePattern('*') gamma_pattern = graph_matcher.OpTypePattern('*') beta_pattern = graph_matcher.OpTypePattern('*') @@ -160,16 +163,27 @@ def _FindFusedBatchNorms(graph): layer_pattern = graph_matcher.OpTypePattern( 'Conv2D|DepthwiseConv2dNative|MatMul', inputs=[input_pattern, weight_pattern]) + batch_to_space_pattern = graph_matcher.OpTypePattern( + 'BatchToSpaceND', + inputs=[ + layer_pattern, + graph_matcher.OpTypePattern('*'), + graph_matcher.OpTypePattern('*') + ]) + layer_output_pattern = graph_matcher.OneofPattern( + [layer_pattern, batch_to_space_pattern]) # MatMul has a Reshape between it and FusedBatchNorm. matmul_reshape_pattern = graph_matcher.OpTypePattern( - 'Reshape', inputs=[layer_pattern, - graph_matcher.OpTypePattern('*')]) + 'Reshape', + inputs=[layer_output_pattern, + graph_matcher.OpTypePattern('*')]) batch_norm_pattern = graph_matcher.OpTypePattern( 'FusedBatchNorm', inputs=[ - graph_matcher.OneofPattern([matmul_reshape_pattern, layer_pattern]), - gamma_pattern, beta_pattern, mean_pattern, variance_pattern + graph_matcher.OneofPattern( + [matmul_reshape_pattern, layer_output_pattern]), gamma_pattern, + beta_pattern, mean_pattern, variance_pattern ]) matmul_bn_output_reshape_pattern = graph_matcher.OpTypePattern( 'Reshape', inputs=[batch_norm_pattern, @@ -192,6 +206,7 @@ def _FindFusedBatchNorms(graph): moving_variance_tensor = None bn_decay_mean_tensor = None bn_decay_var_tensor = None + batch_to_space_op = None layer_op = match_result.get_op(layer_pattern) layer_tensor = match_result.get_tensor(layer_pattern) bn_op = match_result.get_op(batch_norm_pattern) @@ -213,6 +228,7 @@ def _FindFusedBatchNorms(graph): if not output_tensor.consumers(): continue + batch_to_space_op = match_result.get_op(batch_to_space_pattern) input_tensor = match_result.get_tensor(input_pattern) weight_tensor = match_result.get_tensor(weight_pattern) gamma_tensor = match_result.get_tensor(gamma_pattern) @@ -276,7 +292,8 @@ def _FindFusedBatchNorms(graph): moving_variance_tensor=moving_variance_tensor, bn_decay_mean_tensor=bn_decay_mean_tensor, bn_decay_var_tensor=bn_decay_var_tensor, - batch_epsilon=batch_epsilon) + batch_epsilon=batch_epsilon, + batch_to_space_op=batch_to_space_op) def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, @@ -380,7 +397,8 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, return correction_scale, correction_recip, correction_offset -def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor): +def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor, + batch_to_space_op): """Clones layer_op with input_tensor and weight_tensor as new inputs.""" new_layer_name = layer_op.name.split('/')[-1] + '_Fold' if layer_op.type == 'Conv2D': @@ -400,12 +418,25 @@ def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor): transpose_b=layer_op.get_attr('transpose_b'), name=new_layer_name) elif layer_op.type == 'DepthwiseConv2dNative': - return nn.depthwise_conv2d( + conv = nn.depthwise_conv2d( input_tensor, weight_tensor, + rate=layer_op.get_attr('dilations'), strides=layer_op.get_attr('strides'), padding=layer_op.get_attr('padding'), name=new_layer_name) + # Copy the batch to space operation if we have a atrous convolution. + if batch_to_space_op: + batch_to_space_op = layer_op.outputs[0].consumers()[0] + # TODO(suharshs): It's hard to make this name match with the unfused name. + # Restructure this code to not rely on scope at all. + new_batch_to_space_name = batch_to_space_op.name.split('/')[-1] + '_Fold' + conv = array_ops.batch_to_space_nd( + conv, + batch_to_space_op.inputs[1], + batch_to_space_op.inputs[2], + name=new_batch_to_space_name) + return conv else: raise ValueError('Cannot handle operation of type: %s' % layer_op.type) @@ -617,7 +648,8 @@ def _GetBatchNormParams(graph, context, has_scaling): moving_variance_tensor=moving_variance_tensor, bn_decay_mean_tensor=bn_decay_mean_tensor, bn_decay_var_tensor=bn_decay_var_tensor, - batch_epsilon=batch_epsilon) + batch_epsilon=batch_epsilon, + batch_to_space_op=None) def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay, @@ -651,6 +683,11 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay, '/BatchNorm/batchnorm_1/' + mul_scale_name) op_below = mul_scale.inputs[0].op + # Skip over the BatchToSpace operation in the case of atrous convolutions. + batch_to_space_op = None + if op_below.type == 'BatchToSpaceND': + batch_to_space_op = op_below + op_below = op_below.inputs[0].op weights = op_below.inputs[1] match = _GetBatchNormParams( graph=graph, context=context, has_scaling=has_scaling) @@ -691,7 +728,7 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay, context + '/correction_mult') mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights)]) else: - raise ValueError('Cannot handle operation of type: %s' % op_below.op) + raise ValueError('Cannot handle operation of type: %s' % op_below.type) _AssertShapesMatch('mul_fold', mul_fold.inputs[0], mul_fold.outputs[0]) conv_or_fc_folded = _CloneOp(op_below, op_below.name + '_Fold', @@ -701,6 +738,13 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay, context + '/BatchNorm/batchnorm_1/add_1') corrected_output = conv_or_fc_folded.outputs[0] + # Copy the batch to space operation if we have a atrous convolution. + if batch_to_space_op: + corrected_output = array_ops.batch_to_space_nd( + corrected_output, + batch_to_space_op.inputs[1], + batch_to_space_op.inputs[2], + name=batch_to_space_op.name + '_Fold') if correction_offset is not None: with ops.device(conv_or_fc_folded.device): corrected_output = math_ops.multiply(correction_recip, corrected_output, @@ -898,7 +942,8 @@ class _BatchNormMatch(object): def __init__(self, layer_op, bn_op, output_tensor, input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, variance_tensor, moving_mean_tensor, moving_variance_tensor, - bn_decay_mean_tensor, bn_decay_var_tensor, batch_epsilon): + bn_decay_mean_tensor, bn_decay_var_tensor, batch_epsilon, + batch_to_space_op): self._layer_op = layer_op self._bn_op = bn_op self._output_tensor = output_tensor @@ -913,6 +958,7 @@ class _BatchNormMatch(object): self._bn_decay_mean_tensor = bn_decay_mean_tensor self._bn_decay_var_tensor = bn_decay_var_tensor self._batch_epsilon = batch_epsilon + self._batch_to_space_op = batch_to_space_op @property def layer_op(self): @@ -969,3 +1015,7 @@ class _BatchNormMatch(object): @property def bn_decay_var_tensor(self): return self._bn_decay_var_tensor + + @property + def batch_to_space_op(self): + return self._batch_to_space_op diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py index bfa9d3bf705e327091098a8e416b7902f852605a..7c907ffd92c1ae0c762e41cc429b0e6ce053f6b9 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py @@ -438,6 +438,90 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): def testFoldDepthwiseConv2d(self): self._RunTestOverParameters(self._TestFoldDepthwiseConv2d) + def _TestFoldAtrousConv2d(self, relu, relu_op_name, with_bypass, has_scaling, + fused_batch_norm, freeze_batch_norm_delay): + """Tests folding: inputs -> AtrousConv2d with batch norm -> Relu*. + + Args: + relu: Callable that returns an Operation, a factory method for the Relu*. + relu_op_name: String, name of the Relu* operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Relu*. + has_scaling: Bool, when true the batch norm has scaling. + fused_batch_norm: Bool, when true the batch norm is fused. + freeze_batch_norm_delay: None or the number of steps after which training + switches to using frozen mean and variance + """ + g = ops.Graph() + with g.as_default(): + batch_size, height, width = 5, 128, 128 + inputs = array_ops.zeros((batch_size, height, width, 3)) + dilation_rate = 2 + activation_fn = None if with_bypass else relu + scope = 'test/test2' if with_bypass else 'test' + node = separable_conv2d( + inputs, + None, [3, 3], + rate=dilation_rate, + depth_multiplier=1.0, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams( + scale=has_scaling, fused=fused_batch_norm), + scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + relu(node, name='test/' + relu_op_name) + + fold_batch_norms.FoldBatchNorms( + g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay) + + folded_mul = g.get_operation_by_name(scope + '/mul_fold') + self.assertEqual(folded_mul.type, 'Mul') + if fused_batch_norm: + scale_reshape_op_name = scope + '/BatchNorm_Fold/scale_reshape' + else: + scale_reshape_op_name = scope + '/scale_reshape' + self._AssertInputOpsAre(folded_mul, + [scope + '/correction_mult', scale_reshape_op_name]) + self._AssertOutputGoesToOps(folded_mul, g, [scope + '/depthwise_Fold']) + + scale_reshape = g.get_operation_by_name(scale_reshape_op_name) + self.assertEqual(scale_reshape.type, 'Reshape') + self._AssertInputOpsAre(scale_reshape, [ + self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm), + scale_reshape_op_name + '/shape' + ]) + self._AssertOutputGoesToOps(scale_reshape, g, [scope + '/mul_fold']) + + folded_conv = g.get_operation_by_name(scope + '/depthwise_Fold') + self.assertEqual(folded_conv.type, 'DepthwiseConv2dNative') + self._AssertInputOpsAre( + folded_conv, [scope + '/mul_fold', scope + '/depthwise/SpaceToBatchND']) + if fused_batch_norm: + self._AssertOutputGoesToOps(folded_conv, g, + [scope + '/BatchToSpaceND_Fold']) + else: + self._AssertOutputGoesToOps(folded_conv, g, + [scope + '/depthwise/BatchToSpaceND_Fold']) + + folded_add = g.get_operation_by_name(scope + '/add_fold') + self.assertEqual(folded_add.type, 'Add') + self._AssertInputOpsAre(folded_add, [ + scope + '/correction_add', + self._BathNormBiasName(scope, fused_batch_norm) + ]) + output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] + self._AssertOutputGoesToOps(folded_add, g, output_op_names) + + for op in g.get_operations(): + self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) + + def testFoldAtrousConv2d(self): + self._RunTestOverParameters(self._TestFoldAtrousConv2d) + def _TestCompareFoldAndUnfolded(self, relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm, freeze_batch_norm_delay): diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index cbba72643f7f166c473b6181edc292f695c4cbc2..19e5bef1ea48ca4441cdef6b1a74e98e9cf6ddb9 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -194,6 +194,8 @@ def _FindLayersToQuantize(graph): / conv|fc | + [batch_to_space_nd] + | [post_conv_correction] | biasadd|folded_bias @@ -247,9 +249,21 @@ def _FindLayersToQuantize(graph): ], ordered_inputs=False) + # For atrous convolutions a BatchToSpaceND will occur after the depthwise + # convolution. + batch_to_space_pattern = graph_matcher.OpTypePattern( + 'BatchToSpaceND', + inputs=[ + layer_pattern, + graph_matcher.OpTypePattern('*'), + graph_matcher.OpTypePattern('*') + ]) + + layer_output_pattern = graph_matcher.OneofPattern( + [batch_to_space_pattern, layer_pattern]) folded_bias_mul_pattern = graph_matcher.OpTypePattern( 'Mul', - inputs=[graph_matcher.OpTypePattern('*'), layer_pattern], + inputs=[graph_matcher.OpTypePattern('*'), layer_output_pattern], ordered_inputs=False) post_layer_op_correction_pattern = graph_matcher.OpTypePattern( 'Add', @@ -265,7 +279,7 @@ def _FindLayersToQuantize(graph): ordered_inputs=False) bias_add_pattern = graph_matcher.OpTypePattern( - 'Add|BiasAdd', inputs=[layer_pattern, '*'], ordered_inputs=False) + 'Add|BiasAdd', inputs=[layer_output_pattern, '*'], ordered_inputs=False) # The bias can come from the bias add or the folded bias add. bypass_pattern = graph_matcher.OpTypePattern( @@ -373,14 +387,6 @@ def _FindLayersToQuantize(graph): return layer_matches -def _HasPostActivationBypass(activation_op): - for activation_tensor in activation_op.outputs: - for output_op in activation_tensor.consumers(): - if output_op.type == 'Add': - return True - return False - - class _LayerMatch(object): """Contains all information related to a matched Layer.""" diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py index db745aa56212af6a9c20e06ee9e4e5d6e27cf3c3..5e3af0a567536ef6fcfd86d82e94c0ba21077a85 100644 --- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py +++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py @@ -276,6 +276,52 @@ class QuantizeTest(test_util.TensorFlowTestCase): graph, scope, 'DepthwiseConv2dNative', activation_op_name, with_bypass, delay, use_resource) + def testQuantize_AtrousConvWithoutBatchNorm(self): + self._RunWithoutBatchNormTestOverParameters( + self._TestQuantize_AtrousConvWithoutBatchNorm) + + def _TestQuantize_AtrousConvWithoutBatchNorm( + self, activation, activation_op_name, with_bypass, delay, use_resource): + """Tests quantization: inputs -> atrous conv no batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + use_resource: Bool, when true uses resource variables. + """ + graph = ops.Graph() + with graph.as_default(): + variable_scope.get_variable_scope().set_use_resource(use_resource) + batch_size, height, width, depth = 5, 128, 128, 3 + inputs = array_ops.zeros((batch_size, height, width, depth)) + dilation_rate = 2 + activation_fn = None if with_bypass else activation + scope = 'test/test2' if with_bypass else 'test' + node = separable_conv2d( + inputs, + None, [3, 3], + rate=dilation_rate, + depth_multiplier=1.0, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + node = activation(node, name='test/' + activation_op_name) + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + quantize.Quantize(graph, True, quant_delay=delay) + + self._AssertCorrectQuantizedGraphWithoutBatchNorm( + graph, scope, 'DepthwiseConv2dNative', activation_op_name, with_bypass, + delay, use_resource) + def _RunBatchNormTestOverParameters(self, test_fn): # TODO(suharshs): Use parameterized test once OSS TF supports it. parameters_list = [ @@ -543,6 +589,61 @@ class QuantizeTest(test_util.TensorFlowTestCase): graph, scope, 'DepthwiseConv2dNative', activation_op_name, with_bypass, delay, use_resource) + def testQuantize_AtrousConvWithBatchNorm(self): + self._RunBatchNormTestOverParameters( + self._TestQuantize_AtrousConvWithBatchNorm) + + def _TestQuantize_AtrousConvWithBatchNorm( + self, activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_resource): + """Tests quantization: inputs -> atrous conv with batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true use FusedBatchNorm. + use_resource: Bool, when true uses resource variables. + """ + graph = ops.Graph() + with graph.as_default(): + variable_scope.get_variable_scope().set_use_resource(use_resource) + batch_size, height, width, depth = 5, 128, 128, 3 + inputs = array_ops.zeros((batch_size, height, width, depth)) + dilation_rate = 2 + scope = 'test/test2' if with_bypass else 'test' + node = separable_conv2d( + inputs, + None, [3, 3], + rate=dilation_rate, + depth_multiplier=1.0, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams(fused_batch_norm), + scope=scope) + + # Manually add a bypass (optional) and an activation. + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + + node = activation(node, name='test/' + activation_op_name) + + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + + fold_batch_norms.FoldBatchNorms(graph, is_training=True) + quantize.Quantize(graph, True, quant_delay=delay) + + self._AssertCorrectQuantizedGraphWithBatchNorm( + graph, scope, 'DepthwiseConv2dNative', activation_op_name, + with_bypass, delay, use_resource) + def _AssertIdempotent(self, graph): # Ensure that calling the rewrite again doesn't change the graph. graph_def_before = str(graph.as_graph_def()) diff --git a/tensorflow/contrib/recurrent/BUILD b/tensorflow/contrib/recurrent/BUILD index b3cb04ce26d96333f516f1298c8d5c331964f05b..f9827f766da022b184b3348fc24b1570bac8678f 100644 --- a/tensorflow/contrib/recurrent/BUILD +++ b/tensorflow/contrib/recurrent/BUILD @@ -102,5 +102,8 @@ cuda_py_tests( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], - tags = ["nopip"], + tags = [ + "nopip", + "optonly", + ], ) diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py index 67f31785b57fddef67733c18c3b744322532c28c..07227bcb77d353200ee46763d51727ed9c0974a1 100644 --- a/tensorflow/contrib/rnn/__init__.py +++ b/tensorflow/contrib/rnn/__init__.py @@ -58,6 +58,7 @@ See @{$python/contrib.rnn} guide. @@Conv3DLSTMCell @@HighwayWrapper @@GLSTMCell +@@SRUCell @@AttentionCellWrapper diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index b8840a8f2420f1bc6c75f0a02e5465c595378dec..86f1e27abd53d011f37f06851dd6d0977853c8f4 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -443,7 +443,7 @@ class RNNCellTest(test.TestCase): self.assertTrue( float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) < 1e-6) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testWrapperCheckpointing(self): for wrapper_type in [ rnn_cell_impl.DropoutWrapper, diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index be99a5d67a3e49b1d522406601d050392f75e963..1c20d88fe4bcbe2c1f1e3413502dbf276f2d21b3 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -921,7 +921,7 @@ class LSTMTest(test.TestCase): # Smoke test, this should not raise an error rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDynamicRNNWithTupleStates(self): num_units = 3 input_size = 5 @@ -997,7 +997,7 @@ class LSTMTest(test.TestCase): self.assertAllEqual(array_ops.stack(outputs_static), outputs_dynamic) self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDynamicRNNWithNestedTupleStates(self): num_units = 3 input_size = 5 @@ -1285,7 +1285,7 @@ class LSTMTest(test.TestCase): "Comparing individual variable gradients iteration %d" % i) self.assertAllEqual(a, b) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDynamicEquivalentToStaticRNN(self): self._testDynamicEquivalentToStaticRNN(use_sequence_length=False) self._testDynamicEquivalentToStaticRNN(use_sequence_length=False) diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py index e69725ff8ab1ba4de880c914a6f5fdad5e54566d..f58268eff525a4b592c79acb32207e1a3f62bdc7 100644 --- a/tensorflow/contrib/seq2seq/python/ops/decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py @@ -21,6 +21,7 @@ from __future__ import print_function import abc import six +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -182,19 +183,20 @@ def dynamic_decode(decoder, raise TypeError("Expected decoder to be type Decoder, but saw: %s" % type(decoder)) - def _is_xla_tensor(tensor): - try: - op = tensor.op - except AttributeError: - return False - if control_flow_util.IsInXLAContext(op): - return True - return False - with variable_scope.variable_scope(scope, "decoder") as varscope: - # Properly cache variable values inside the while_loop - if varscope.caching_device is None: - varscope.set_caching_device(lambda op: op.device) + # Determine context types. + ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access + is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None + in_while_loop = ( + control_flow_util.GetContainingWhileContext(ctxt) is not None) + # Properly cache variable values inside the while_loop. + # Don't set a caching device when running in a loop, since it is possible + # that train steps could be wrapped in a tf.while_loop. In that scenario + # caching prevents forward computations in loop iterations from re-reading + # the updated weights. + if not context.executing_eagerly() and not in_while_loop: + if varscope.caching_device is None: + varscope.set_caching_device(lambda op: op.device) if maximum_iterations is not None: maximum_iterations = ops.convert_to_tensor( @@ -208,9 +210,6 @@ def dynamic_decode(decoder, decoder.output_dtype, decoder.batch_size) - is_xla = False - if any([_is_xla_tensor(i) for i in nest.flatten(initial_inputs)]): - is_xla = True if is_xla and maximum_iterations is None: raise ValueError("maximum_iterations is required for XLA compilation.") if maximum_iterations is not None: diff --git a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py index 03d6da7765ba5249a9fb22f56a469cf07c310479..f10d78259a3be3a3a6f7f78c196ab107f18a53aa 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py @@ -147,7 +147,7 @@ class SpectralOpsTest(test.TestCase): inverse_stft = spectral_ops.inverse_stft(stft, frame_length=8, fft_length=16, frame_step=8) expected_length = (stft.shape[0] - 1) * 8 + 8 - self.assertAllEqual([None], inverse_stft.shape.as_list()) + self.assertAllEqual([256], inverse_stft.shape.as_list()) self.assertAllEqual([expected_length], inverse_stft.eval().shape) def test_stft_and_inverse_stft(self): diff --git a/tensorflow/contrib/signal/python/kernel_tests/test_util.py b/tensorflow/contrib/signal/python/kernel_tests/test_util.py index 9a3603b6a97ef7c3a4b940b83281ebceda93c9db..7d6289532addfd4b4b867bf64d9113253bd1c76d 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/test_util.py +++ b/tensorflow/contrib/signal/python/kernel_tests/test_util.py @@ -39,6 +39,7 @@ def grappler_optimize(graph, fetches=None, rewriter_config=None): """ if rewriter_config is None: rewriter_config = rewriter_config_pb2.RewriterConfig() + rewriter_config.min_graph_nodes = -1 if fetches is not None: for fetch in fetches: graph.add_to_collection('train_op', fetch) diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py index 3d0308aaf3da3b5b16fd22a2905db36917e8c97b..2c97834523424d0fab56330b4d9355a75427e0ef 100644 --- a/tensorflow/contrib/slim/python/slim/evaluation_test.py +++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py @@ -33,7 +33,6 @@ from tensorflow.python.debug.lib import debug_data from tensorflow.python.debug.wrappers import hooks from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics @@ -242,7 +241,7 @@ class SingleEvaluationTest(test.TestCase): checkpoint_path = os.path.join(self.get_temp_dir(), 'this_file_doesnt_exist') log_dir = os.path.join(self.get_temp_dir(), 'error_raised') - with self.assertRaises(errors.NotFoundError): + with self.assertRaises(ValueError): evaluation.evaluate_once('', checkpoint_path, log_dir) def _prepareCheckpoint(self, checkpoint_path): diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index f1ef218e74bbd225071324a8269fdfeb5de0e038..3e41e3d0b48ea06f9cb8c1862e27eacb5ebc4417 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -81,6 +81,19 @@ class EagerFileTest(test_util.TensorFlowTestCase): # test here that we're calling them correctly. self.assertTrue(gfile.Exists(logdir)) + @test_util.assert_no_new_pyobjects_executing_eagerly + def testEagerMemory(self): + training_util.get_or_create_global_step() + logdir = self.get_temp_dir() + with summary_ops.create_file_writer( + logdir, max_queue=0, + name='t0').as_default(), summary_ops.always_record_summaries(): + summary_ops.generic('tensor', 1, '') + summary_ops.scalar('scalar', 2.0) + summary_ops.histogram('histogram', [1.0]) + summary_ops.image('image', [[[[1.0]]]]) + summary_ops.audio('audio', [[1.0]], 1.0, 1) + def testDefunSummarys(self): training_util.get_or_create_global_step() logdir = tempfile.mkdtemp() diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index a5d8b061b6b26f9d05be40a1162481ae219b0e9c..adda0b758b172f5e80c165e4b28dbdbecef2ba16 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -49,7 +49,6 @@ tf_cuda_cc_test( tf_custom_op_library( name = "python/ops/_trt_engine_op.so", srcs = [ - "ops/trt_calib_op.cc", "ops/trt_engine_op.cc", ], deps = [ @@ -76,11 +75,9 @@ tf_cuda_library( cc_library( name = "trt_engine_op_kernel", srcs = [ - "kernels/trt_calib_op.cc", "kernels/trt_engine_op.cc", ], hdrs = [ - "kernels/trt_calib_op.h", "kernels/trt_engine_op.h", ], copts = tf_copts(), @@ -89,20 +86,22 @@ cc_library( ":trt_logging", ":trt_plugins", ":trt_resources", + ":trt_conversion", + ":utils", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:stream_executor_headers_lib", + "//tensorflow/core/grappler/costs:graph_properties", ] + if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]) + tf_custom_op_library_additional_deps(), - # TODO(laigd) + # TODO(laigd): fix this by merging header file in cc file. alwayslink = 1, # buildozer: disable=alwayslink-with-hdrs ) tf_gen_op_libs( op_lib_names = [ "trt_engine_op", - "trt_calib_op", ], ) @@ -122,7 +121,6 @@ tf_gen_op_wrapper_py( name = "trt_engine_op", gen_locally = True, deps = [ - ":trt_calib_op_op_lib", ":trt_engine_op_op_lib", ":trt_logging", ":trt_shape_function", @@ -140,7 +138,6 @@ tf_custom_op_py_library( kernels = [ ":trt_engine_op_kernel", ":trt_engine_op_op_lib", - ":trt_calib_op_op_lib", ":trt_shape_function", ], srcs_version = "PY2AND3", @@ -191,7 +188,6 @@ tf_py_wrap_cc( deps = [ ":trt_conversion", ":trt_engine_op_kernel", - "//tensorflow/core:framework_lite", "//third_party/python_runtime:headers", ], ) @@ -211,6 +207,7 @@ tf_cuda_library( ], deps = [ ":trt_logging", + ":utils", "//tensorflow/core:framework_headers_lib", "//tensorflow/core:framework_lite", "//tensorflow/core:lib_proto_parsing", @@ -237,12 +234,12 @@ tf_cuda_library( ":trt_plugins", ":trt_logging", ":trt_resources", + ":utils", "//tensorflow/core/grappler/clusters:cluster", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", - "//tensorflow/core:framework", "//tensorflow/core:gpu_runtime", "//tensorflow/core:framework_lite", "//tensorflow/core:graph", @@ -343,3 +340,8 @@ py_test( "//tensorflow/python:framework_test_lib", ], ) + +cc_library( + name = "utils", + hdrs = ["convert/utils.h"], +) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index da4dd5a14cd74591fc9df63cd5868044e4e369ec..4dc1c551cc585cbfd0bacdce843b1eae82f5054e 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/convert/convert_graph.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include #include #include #include @@ -24,10 +24,17 @@ limitations under the License. #include #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" #include "tensorflow/contrib/tensorrt/segment/segment.h" #include "tensorflow/core/common_runtime/gpu/gpu_id.h" #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" #include "tensorflow/core/common_runtime/gpu/process_state.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -39,17 +46,39 @@ limitations under the License. #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" // NOLINT #include "tensorflow/core/protobuf/device_properties.pb.h" // NOLINT +#include "tensorflow/core/protobuf/rewriter_config.pb.h" // NOLINT +#include "tensorflow/core/util/device_name_utils.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT +#include "cuda/include/cuda_runtime_api.h" #include "tensorrt/include/NvInfer.h" - namespace tensorflow { namespace tensorrt { namespace convert { +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + +// Returns compiled TRT version information {Maj, Min, Patch} +std::vector GetLinkedTensorRTVersion() { + return {NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR, NV_TENSORRT_PATCH}; +} + +// Returns loaded TRT library version {Maj, Min, Patch} +std::vector GetLoadedTensorRTVersion() { + int ver = getInferLibVersion(); + int ver_major = ver / 1000; + ver = ver - ver_major * 1000; + int ver_minor = ver / 100; + int ver_patch = ver - ver_minor * 100; + return {ver_major, ver_minor, ver_patch}; +} + namespace { bool IsTensorRTCandidate(const tensorflow::Node* node) { @@ -82,229 +111,6 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) { PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); } -void GetSubGraphIncomingEdges(const tensorflow::Graph& graph, - const std::set& subgraph_node_ids, - tensorflow::EdgeSet* incoming_edges) { - for (int node_id : subgraph_node_ids) { - const tensorflow::Node* node = graph.FindNodeId(node_id); - for (const tensorflow::Edge* edge : node->in_edges()) { - if (!subgraph_node_ids.count(edge->src()->id()) && - !edge->src()->IsSource() && !edge->IsControlEdge()) { - incoming_edges->insert(edge); - VLOG(2) << "INCOMING " << edge->src()->name() << " -> " << node->name() - << " Y, "; - } else { - VLOG(2) << "INCOMING " << edge->src()->name() << " -> " << node->name() - << " N, "; - } - } - } -} - -void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph, - const std::set& subgraph_node_ids, - tensorflow::EdgeSet* outgoing_edges) { - for (int node_id : subgraph_node_ids) { - const tensorflow::Node* node = graph.FindNodeId(node_id); - for (const tensorflow::Edge* edge : node->out_edges()) { - if (!subgraph_node_ids.count(edge->dst()->id()) && - !edge->dst()->IsSink() && !edge->IsControlEdge()) { - VLOG(2) << "OUTGOING " << node->name() << " -> " << edge->dst()->name() - << " Y, "; - outgoing_edges->insert(edge); - } else { - VLOG(2) << "OUTGOING " << node->name() << " -> " << edge->dst()->name() - << " N, "; - } - } - } -} - -std::pair ParseTensorName(const string& name, - int default_idx = 0) { - string name_no_idx = name; - int idx = default_idx; - const size_t sep = name_no_idx.find_last_of(':'); - if (sep != string::npos) { - name_no_idx = name_no_idx.substr(0, sep); - idx = std::stoi(name.substr(sep + 1)); - } - return std::make_pair(name_no_idx, idx); -} - -std::unordered_map> BuildTensorNameMap( - const std::vector& tensor_names) { - std::unordered_map> result; - for (const string& tensor_name : tensor_names) { - string node_name; - int index; - std::tie(node_name, index) = ParseTensorName(tensor_name); - result[node_name].push_back(index); - } - return result; -} - -// TODO(sami): convert references to pointers -struct ConvertGraphParams { - ConvertGraphParams( - tensorflow::Graph& inp_graph, - const std::vector& output_node_names, - const std::set& subgraph_node_id_numbers, - size_t max_supported_batch_size, size_t max_consumed_workspace_size_bytes, - const tensorflow::grappler::GraphProperties& current_graph_properties, - std::unordered_map>* output_edges, - int engine_precision_mode, const string& device_name, - std::shared_ptr allocator, int cuda_gpu_id) - : graph(inp_graph), - output_names(output_node_names), - subgraph_node_ids(subgraph_node_id_numbers), - max_batch_size(max_supported_batch_size), - max_workspace_size_bytes(max_consumed_workspace_size_bytes), - graph_properties(current_graph_properties), - output_edge_map(output_edges), - precision_mode(engine_precision_mode), - device_name_(device_name), - allocator_(allocator), - cuda_gpu_id_(cuda_gpu_id) {} - tensorflow::Graph& graph; - const std::vector& output_names; - const std::set& subgraph_node_ids; - size_t max_batch_size; - size_t max_workspace_size_bytes; - const tensorflow::grappler::GraphProperties& graph_properties; - std::unordered_map>* output_edge_map; - int precision_mode; - string device_name_; - std::shared_ptr allocator_; - int cuda_gpu_id_; - std::vector> subgraph_inputs; - std::vector> subgraph_outputs; - tensorflow::EdgeSet subgraph_incoming_edges; - tensorflow::EdgeSet subgraph_outgoing_edges; -}; - -static tensorflow::Status FillSubGraphEdgeSets(ConvertGraphParams* p) { - GetSubGraphIncomingEdges(p->graph, p->subgraph_node_ids, - &p->subgraph_incoming_edges); - - std::set> unique_tensors; - // Add only unique input source nodes. If output of an outside node is shared - // between multiple nodes inside the engine, only one edge should be created - for (const tensorflow::Edge* edge : p->subgraph_incoming_edges) { - unique_tensors.insert({edge->src()->id(), edge->src_output()}); - } - p->subgraph_inputs.insert(p->subgraph_inputs.begin(), unique_tensors.begin(), - unique_tensors.end()); - GetSubGraphOutgoingEdges(p->graph, p->subgraph_node_ids, - &p->subgraph_outgoing_edges); - unique_tensors.clear(); - // Similar to above, if multiple ouside nodes are sharing the output of an - // internal node only one output port should be created and shared between - // outputs - for (const tensorflow::Edge* edge : p->subgraph_outgoing_edges) { - unique_tensors.insert({edge->src()->id(), edge->src_output()}); - } - p->subgraph_outputs.reserve(unique_tensors.size()); - p->subgraph_outputs.insert(p->subgraph_outputs.begin(), - unique_tensors.begin(), unique_tensors.end()); - return tensorflow::Status::OK(); -} - -tensorflow::Status GetCalibNode(ConvertGraphParams* params) { - TF_RETURN_IF_ERROR(FillSubGraphEdgeSets(params)); - tensorflow::NodeDef trt_node_def; - SubGraphParams s(params->graph, params->subgraph_node_ids, - params->subgraph_inputs, params->subgraph_outputs, - params->max_batch_size, params->max_workspace_size_bytes, - params->graph_properties, params->output_edge_map, - &trt_node_def, params->precision_mode, params->device_name_, - params->allocator_, params->cuda_gpu_id_); - TF_RETURN_IF_ERROR(InjectCalibrationNode(s)); - tensorflow::Status status; - tensorflow::Node* trt_node = params->graph.AddNode(trt_node_def, &status); - - TF_RETURN_IF_ERROR(status); - - for (auto in_edge : - params->subgraph_incoming_edges) { // loop over incoming edges and - // attach them to calib node - auto src_output = in_edge->src_output(); - auto dst_node = in_edge->dst(); - auto dst_input = in_edge->dst_input(); - VLOG(1) << " update edge " << trt_node->name() << ":" << src_output - << " -> " << dst_node->name() << ":" << dst_input; - TF_RETURN_IF_ERROR( - params->graph.UpdateEdge(trt_node, src_output, dst_node, dst_input)); - } - return tensorflow::Status::OK(); -} - -tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) { - TF_RETURN_IF_ERROR(FillSubGraphEdgeSets(params)); - tensorflow::NodeDef trt_node_def; - - SubGraphParams s(params->graph, params->subgraph_node_ids, - params->subgraph_inputs, params->subgraph_outputs, - params->max_batch_size, params->max_workspace_size_bytes, - params->graph_properties, params->output_edge_map, - &trt_node_def, params->precision_mode, params->device_name_, - params->allocator_, params->cuda_gpu_id_); - TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(s)); - tensorflow::Status status; - tensorflow::Node* trt_node = params->graph.AddNode(trt_node_def, &status); - - // AddNode does not wire edges. - // Re-map incoming edges to use the new TRT node instead of the orig subgraph - std::map, int> subgraph_edge_to_input_map; - for (size_t i = 0; i < params->subgraph_inputs.size(); ++i) { - subgraph_edge_to_input_map.insert({params->subgraph_inputs.at(i), i}); - } - std::set> unique_tensors; - for (const tensorflow::Edge* edge : params->subgraph_incoming_edges) { - std::pair old_src = {edge->src()->id(), edge->src_output()}; - if (unique_tensors.count(old_src)) continue; - unique_tensors.insert(old_src); - int new_src_output = subgraph_edge_to_input_map.at(old_src); - params->graph.AddEdge(edge->src(), edge->src_output(), trt_node, - new_src_output); - VLOG(1) << "Wire " << edge->src()->name() << ":" << edge->src_output() - << " -> " << trt_node->name() << ":" << new_src_output; - params->graph.RemoveEdge(edge); - } - if (VLOG_IS_ON(2)) { - VLOG(2) << "new edge count: " << trt_node->in_edges().size(); - for (const tensorflow::Edge* edge : trt_node->in_edges()) { - VLOG(2) << edge->src()->name() << " port: " << edge->src_output(); - } - } - TF_RETURN_IF_ERROR(status); - - // Re-map outgoing edges to use the new TRT node instead of the orig subgraph - std::map, int> subgraph_edge_to_output_map; - for (size_t i = 0; i < params->subgraph_outputs.size(); ++i) { - subgraph_edge_to_output_map.insert({params->subgraph_outputs.at(i), i}); - } - TF_RETURN_IF_ERROR(status); - for (const tensorflow::Edge* edge : params->subgraph_outgoing_edges) { - std::pair old_src = {edge->src()->id(), edge->src_output()}; - int new_src_output = subgraph_edge_to_output_map.at(old_src); - TF_RETURN_IF_ERROR(params->graph.UpdateEdge( - trt_node, new_src_output, edge->dst(), edge->dst_input())); - VLOG(1) << "Wire " << trt_node->name() << ":" << new_src_output << " -> " - << edge->dst()->name() << ":" << edge->dst_input(); - } - // Remove the original subgraph - for (int node_id : params->subgraph_node_ids) { - tensorflow::Node* node = params->graph.FindNodeId(node_id); - // Don't remove the input placeholders - if (node->type_string() == "Placeholder") { - continue; - } - params->graph.RemoveNode(node); - } - return tensorflow::Status::OK(); -} - tensorflow::Status BuildNodeMap( const tensorflow::Graph& graph, std::unordered_map* node_map) { @@ -318,51 +124,77 @@ tensorflow::Status BuildNodeMap( } } // namespace + +// Function to get calibration from ResourceMgr and put them into nodedef. tensorflow::Status ConvertCalibGraphToInferGraph( - const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph) { + const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph, + bool is_dyn_op) { VLOG(0) << "Starting Calib Conversion"; - tensorflow::Graph graph(tensorflow::OpRegistry::Global()); - TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph( - tensorflow::GraphConstructorOptions(), graph_def, &graph)); - // get calib nodes - std::vector calib_nodes; - std::vector topo_order; - tensorflow::GetPostOrder(graph, &topo_order); - for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) { - auto node = *rit; - if (node->type_string() == "TRTCalibOp") { - VLOG(1) << "Found Calib Node " << node->name(); - calib_nodes.push_back(node); - } + infer_graph->CopyFrom(graph_def); + auto trt_rm = TRTResourceManager::instance(); + auto calib_rm = trt_rm->getManager("TRTCalibration"); + int num_nodes = infer_graph->node_size(); + if (!is_dyn_op) { + LOG(WARNING) << "Construction of static int8 engine is not implemented " + "yet!. Dynamic engine will be constructed"; } - VLOG(0) << "Num Calib nodes in graph= " << calib_nodes.size(); - if (calib_nodes.size() == 0) - return tensorflow::errors::FailedPrecondition( - "Graph doesn't contain any calibration nodes!." - " Please generate calibration graph and run calibration first"); - for (auto n : calib_nodes) { - TF_RETURN_IF_ERROR( - tensorrt::convert::ConvertCalibrationNodeToEngineNode(graph, n)); + for (int i = 0; i < num_nodes; ++i) { + auto n = infer_graph->mutable_node(i); + if (n->op() == "TRTEngineOp") { + VLOG(1) << "Processing " << n->name(); + string container_name = n->attr().at("segment_funcdef_name").s(); + TRTCalibrationResource* cres = nullptr; + auto status = calib_rm->Lookup(container_name, "Calibrator", &cres); + if (!status.ok()) { + LOG(ERROR) << "Could not get Calibration information. Did you run with " + "calibration data?"; + return tensorflow::errors::FailedPrecondition( + "Need to run graph with calibration data first!"); + } + if (cres->calibrator_) { + cres->calibrator_->setDone(); + cres->thr_->join(); + const auto& calibration_table = + cres->calibrator_->getCalibrationTableAsString(); + if (!calibration_table.size()) { + LOG(ERROR) << "Calibration table is empty"; + return tensorflow::errors::Unknown( + "Calibration table is missing. This shouldn't have happened!"); + } + n->mutable_attr()->at("calibration_data").set_s(calibration_table); + } else { + LOG(ERROR) << "Can't get TRTCalibrator from resource manager!"; + return tensorflow::errors::Unknown( + "Can't get TRTCalibrator from resource manager!"); + } + cres->Unref(); + } } - graph.ToGraphDef(infer_graph); return tensorflow::Status::OK(); } +// Entry function from Python. tensorflow::Status ConvertGraphDefToTensorRT( const tensorflow::GraphDef& graph_def, const std::vector& output_names, size_t max_batch_size, size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, - int precision_mode = FP32MODE, int minimum_segment_size = 3) { + int precision_mode, int minimum_segment_size, bool is_dyn_op, + int max_cached_engines, std::vector cached_engine_batches) { // optimization pass tensorflow::grappler::GrapplerItem item; item.fetch = output_names; item.graph = graph_def; - + // grappler requires a virtual cluster with a proper GPU device + // in order to calculate flops>0 or fails with FATAL + // We add numbers from a Pascal card here to have flops>0 tensorflow::DeviceProperties device_properties; device_properties.set_type("GPU"); device_properties.mutable_environment()->insert({"architecture", "6"}); - tensorflow::grappler::Cluster* cluster = - new tensorflow::grappler::VirtualCluster({{"/GPU:0", device_properties}}); + device_properties.set_num_cores(3584); + device_properties.set_frequency(1531); + std::unique_ptr cluster( + new tensorflow::grappler::VirtualCluster( + {{"/GPU:0", device_properties}})); // single machine int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores(); @@ -370,134 +202,633 @@ tensorflow::Status ConvertGraphDefToTensorRT( VLOG(2) << "cpu_cores: " << num_cpu_cores; VLOG(2) << "gpus: " << num_gpus; tensorflow::RewriterConfig rw_cfg; + // use only const folding and layout for the time being since new optimizers + // break the graph for us + rw_cfg.add_optimizers("constfold"); + rw_cfg.add_optimizers("layout"); + rw_cfg.set_meta_optimizer_iterations(tensorflow::RewriterConfig::ONE); tensorflow::grappler::MetaOptimizer meta_opt(nullptr, rw_cfg); tensorflow::GraphDef gdef; - TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster, item, &gdef)); + TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster.get(), item, &gdef)); item.graph = gdef; // AJ refactoring shape inference through grappler/GraphProperties. tensorflow::grappler::GraphProperties static_graph_properties(item); TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); // Build full graph - - return ConvertAfterShapes(gdef, output_names, max_batch_size, - max_workspace_size_bytes, new_graph_def, - precision_mode, minimum_segment_size, - static_graph_properties, nullptr); + ConversionParams cp; + cp.input_graph_def = &gdef; + cp.output_names = &output_names; + cp.max_batch_size = max_batch_size; + cp.output_graph_def = new_graph_def; + cp.precision_mode = precision_mode; + cp.is_dyn_op = is_dyn_op; + cp.max_cached_engines = max_cached_engines; + cp.cached_engine_batches = cached_engine_batches; + cp.minimum_segment_size = minimum_segment_size; + cp.graph_properties = &static_graph_properties; + cp.max_workspace_size_bytes = max_workspace_size_bytes; + if (VLOG_IS_ON(5)) { + std::fstream f; + f.open("TRTConversionInput.pb", + std::fstream::out | std::fstream::binary | std::fstream::trunc); + f << gdef.SerializeAsString(); + f.close(); + } + return ConvertAfterShapes(cp); } -tensorflow::Status ConvertAfterShapes( - const tensorflow::GraphDef& gdef, const std::vector& output_names, - size_t max_batch_size, size_t max_workspace_size_bytes, - tensorflow::GraphDef* new_graph_def, int precision_mode, - int minimum_segment_size, +// Function to get subsegment information structure. +tensorflow::Status GetEngineInfo( + const tensorflow::Graph* g, const tensorflow::grappler::GraphProperties& graph_properties, - const tensorflow::grappler::Cluster* cluster) { - // Segment the graph into subgraphs that can be converted to TensorRT - tensorflow::tensorrt::segment::SegmentOptions segment_options; + const std::set& segment_nodes, + const std::unordered_map& node_map, + const std::vector& reverse_topo_order, + EngineInfo* info) { + std::vector subgraph_node_ids; + std::set segment_devices; + int input_port = 0; + int output_port = 0; + + // Map from src_node_name+port to the unique port numbers of the TRT op, where + // the src_node_name is the name of the source node of the input/output + // edge, thus there must not be any duplicates since source nodes of + // input/output edges must be in different split of the graph. + // TODO(aaroey): consider using node id and port instead. + std::unordered_map created_edges; + for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend(); + ++it) { + const auto& node_name = (*it)->name(); + + if (segment_nodes.count(node_name) == 0) continue; + auto node = node_map.at(node_name); + auto node_device = node->requested_device(); + if (!node_device.empty()) { + segment_devices.insert(node_device); + } else { + if (node->has_assigned_device_name()) { + segment_devices.insert(node->assigned_device_name()); + } else { + VLOG(2) << "Node " << node->name() + << " neither have requested device nor assigned device"; + } + } + int node_id = node->id(); + subgraph_node_ids.push_back(node_id); + for (const auto edge : node->in_edges()) { + auto input_node = edge->src(); + if (segment_nodes.count(input_node->name()) == 0) { + // Add constant input node into the segment. We don't care if it has + // other output edges going into other engines or TF nodes. Since we add + // it only to the subsegment node list, not the subsegment itself, it + // won't be removed from the graph. If it doesn't have any edges, TF + // will prune it out. + if (input_node->type_string() == "Const") { + subgraph_node_ids.push_back(input_node->id()); + } else if (!edge->IsControlEdge() && !input_node->IsSource()) { + string s(input_node->name()); + StrAppend(&s, ":", edge->src_output()); + VLOG(1) << "Input edge = " << s; + int port = input_port; + if (created_edges.count(s)) { + port = created_edges.at(s); + } else { + created_edges.insert({s, port}); + input_port++; + } + info->connections.emplace_back(input_node->name(), input_node->id(), + edge->src_output(), node_name, node_id, + edge->dst_input(), true, port); + } + } + } + for (const auto edge : node->out_edges()) { + auto output_node = edge->dst(); + if (segment_nodes.count(output_node->name()) == 0 && + !edge->IsControlEdge() && !output_node->IsSink()) { + string s(node_name); + StrAppend(&s, ":", edge->src_output()); + VLOG(1) << "Output edge = " << s; + int port = output_port; + if (created_edges.count(s)) { + port = created_edges.at(s); + } else { + created_edges.insert({s, port}); + output_port++; + } + info->connections.emplace_back(output_node->name(), output_node->id(), + edge->dst_input(), node_name, node_id, + edge->src_output(), false, port); + } + } + } + + TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef( + g, graph_properties, subgraph_node_ids, &info->connections, + &info->segment_graph_def, &info->engine_name)); + // TODO(sami): This should not happen once segmenter is updated. + if (segment_devices.size() == 1) { + info->device = *segment_devices.begin(); + } else if (segment_devices.size() > 1) { + LOG(WARNING) << "Detected multiple(" << segment_devices.size() + << ") devices for the segment. Picking first one to continue " + << "but this shouldn't have happened"; + info->device = *segment_devices.begin(); + } else { + VLOG(1) << "Segment devices size is 0"; + } + return Status::OK(); +} + +// Function to insert a TRT node into the graph. The graph is not modified if +// the returned status is not ok. +// 'alloc' is only used for creating static engine. +tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, + const std::vector& infos, int pos, + nvinfer1::IGpuAllocator* alloc, + int max_batch_size) { + const auto& info = infos.at(pos); + std::vector out_shapes; + std::vector input_shapes; + std::vector shapes; + std::vector inputs; + std::vector out_types; + VLOG(1) << "Processing " << info.engine_name; + + // Update the shape and data types of input/output nodes, and find all unique + // inputs. + for (const auto& conn : info.connections) { + if (!conn.is_input_edge) { + // Set the shapes and data types of output edge. + tensorflow::TensorShapeProto out_shape; + // shape of the output node inside segment + conn.inside_shape.AsProto(&out_shape); + if (out_shapes.size() <= conn.port_number) { + out_shapes.resize(conn.port_number + 1); + out_types.resize(conn.port_number + 1); + } + out_shapes.at(conn.port_number) = out_shape; + out_types.at(conn.port_number) = conn.connection_type; + continue; + } + + // Set the shapes and data types of input edge. + tensorflow::TensorShapeProto in_shape; + conn.outside_shape.AsProto(&in_shape); + if (input_shapes.size() <= conn.port_number) { + input_shapes.resize(conn.port_number + 1); + shapes.resize(conn.port_number + 1); + } + input_shapes.at(conn.port_number) = in_shape; + shapes.at(conn.port_number) = conn.outside_shape; + + string input_node = conn.outside_node_name; + int input_port = conn.outside_port; + bool found_engine = false; + // Rewire the inputs to other engines if they contain original input node. + // Note that we use the information of the engine here, not the information + // of the created TRT nodes, so we're able to find all the connections to + // any other engines beforehand. + for (size_t t = 0; t < infos.size(); ++t) { + if (t == pos) continue; + auto& engine_info = infos.at(t); + for (const auto& eng_conn : engine_info.connections) { + if (eng_conn.is_input_edge) continue; + if (eng_conn.inside_node_name == input_node) { + input_node = engine_info.engine_name; + if (eng_conn.inside_port == input_port) { + input_port = eng_conn.port_number; + found_engine = true; + break; + } + } + } + if (found_engine) break; + } + VLOG(1) << "Engine Input " << input_node << ":" << input_port << " -> " + << info.engine_name << ":" << inputs.size(); + // Skip duplicate inputs. + bool new_input = true; + for (const auto& inp : inputs) { + if (inp.node == input_node && inp.index == input_port) { + new_input = false; + break; + } + } + if (new_input) { + inputs.emplace_back(input_node, input_port, conn.connection_type); + } + } + + // Build the engine and get its serialized representation. + string segment_string; + if (info.engine_type == EngineInfo::EngineType::TRTStatic || + info.precision_mode == INT8MODE) { + // Create static engine for fp32/fp16 mode, and test validity of the engine + // for int8 mode. We don't want engine to fail at the calibration time. + // So we are constructing a FP32 engine here to check its validity, and if + // it is a valid engine then we put the serialized graphdef to the op. + // Otherwise we skip node creation for this engine. + Logger trt_logger; + TrtUniquePtrType engine; + // TODO(sami): What happens if 1st dim is not batch? + TF_RETURN_IF_ERROR(ConvertGraphDefToEngine( + info.segment_graph_def, + info.precision_mode == INT8MODE ? FP32MODE : info.precision_mode, + max_batch_size, info.max_workspace_size_bytes, shapes, &trt_logger, + alloc, /*calibrator=*/nullptr, &engine, + /*convert_successfully=*/nullptr)); + TrtUniquePtrType engine_data(engine->serialize()); + segment_string = + string((const char*)engine_data->data(), engine_data->size()); + if (info.precision_mode == INT8MODE) { + // See above comment about why not putting this inside the 'else' branch. + segment_string = info.segment_graph_def.SerializeAsString(); + } + } else { + segment_string = info.segment_graph_def.SerializeAsString(); + } + + // TODO(aaroey): use enum instead, and add a helper method to do the + // conversion. + string prec_string; + switch (info.precision_mode) { + case FP32MODE: + prec_string = "FP32"; + break; + case FP16MODE: + prec_string = "FP16"; + break; + case INT8MODE: + prec_string = "INT8"; + if (!TRTResourceManager::instance()->getManager("TRTCalibration")) { + LOG(ERROR) << "Failed to construct calibration storage"; + } + break; + default: + return tensorflow::errors::OutOfRange("Unknown precision mode"); + } + tensorflow::NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp"); + if (!info.device.empty()) node_builder.Device(info.device); + if (VLOG_IS_ON(1)) { + string ins = StrCat(info.engine_name, " inputs= "); + for (const auto& ii : inputs) { + StrAppend(&ins, ii.node, ":", ii.index, " "); + } + VLOG(1) << ins; + } + node_builder.Input(inputs); + if (info.engine_type == EngineInfo::EngineType::TRTStatic && + info.cached_engine_batches.size()) { + LOG(WARNING) << "Cached engine batches are ignored for static engines"; + } + tensorflow::NodeDef trt_node; + tensorflow::Status status = + node_builder.Attr("input_shapes", input_shapes) + .Attr("output_shapes", out_shapes) + .Attr("static_engine", + info.engine_type == EngineInfo::EngineType::TRTStatic) + .Attr("segment_funcdef_name", + StrCat(info.engine_name, "_native_segment")) + .Attr("serialized_segment", segment_string) + .Attr("calibration_data", "") + .Attr("max_cached_engines_count", info.maximum_cached_engines) + .Attr("cached_engine_batches", {max_batch_size}) + .Attr("workspace_size_bytes", info.max_workspace_size_bytes) + .Attr("precision_mode", prec_string) + .Attr("OutT", out_types) + .Finalize(&trt_node); + if (!status.ok()) { + LOG(ERROR) << "Node construction failed with" << status; + return status; + } + VLOG(1) << "Adding TRTEngine " << info.engine_name << " to graph"; + + // Up until this point, graph is not modified. If we return !status.ok() from + // here, this segment will be skipped + tensorflow::Node* engine_node = graph->AddNode(trt_node, &status); + if (!status.ok()) { + LOG(ERROR) << "Adding node failed " << status; + return status; + } + // Updates the inputs of output edges destination nodes, and point them to the + // engine node. + for (auto& conn : info.connections) { + if (conn.is_input_edge) continue; + VLOG(1) << " Updating DBG " << engine_node->name() << " out_port " + << conn.port_number << " out_id " << conn.outside_id + << " name=" << conn.outside_node_name; + auto dst_node = graph->FindNodeId(conn.outside_id); + // dst_node can only be removed if it is an input node of another engine. + // In this case, other engines input edge is updated in nodedef to point to + // this engine. Even though edge doesn't exists in the graph, when it is + // deserialized again, correct edges will be constructed. This is a problem + // of graph->AddNode(). + if (!dst_node) continue; + VLOG(1) << "Updating " << engine_node->name() << ":" << conn.port_number + << " to " << dst_node->name() << ":" << conn.outside_port; + auto new_edge = graph->AddEdge(engine_node, conn.port_number, dst_node, + conn.outside_port); + CHECK(new_edge) << "Adding a new edge failed " << engine_node->name() << ":" + << conn.port_number << " -> " << dst_node->name() << ":" + << conn.outside_port; + } + return status; +} + +// Function to construct a funcdef from the segment and add it to the graph. +tensorflow::Status RegisterSegmentFunctionToFunctionLibrary( + tensorflow::Graph* graph, const tensorflow::GraphDef& segment, + const string& name) { + tensorflow::Graph sgraph(graph->flib_def()); + tensorflow::GraphConstructorOptions gcopts; + TF_RETURN_IF_ERROR( + tensorflow::ConvertGraphDefToGraph(gcopts, segment, &sgraph)); + std::map io_nodes; + int num_inputs = 0; + for (auto n : sgraph.op_nodes()) { + if (tensorflow::str_util::StartsWith(n->name(), kInputPHName)) { + num_inputs++; + io_nodes.insert({n->name(), n}); + } else if (tensorflow::str_util::StartsWith(n->name(), kOutputPHName)) { + io_nodes.insert({n->name(), n}); + } + } + + for (int i = 0; i < num_inputs; ++i) { + auto name = StrCat(kInputPHName, i); + auto node = io_nodes[name]; + tensorflow::NodeDef nd; + tensorflow::NodeDefBuilder node_builder( + StrCat(name, "_Arg"), tensorflow::FunctionLibraryDefinition::kArgOp); + VLOG(1) << "Adding " << StrCat(name, "_Arg"); + TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0)) + .Attr("index", i) + .Finalize(&nd)); + tensorflow::Status s; + auto node_arg = sgraph.AddNode(nd, &s); + if (!s.ok()) { + LOG(ERROR) << "Couldn't add _Arg node for " << name; + } + for (auto edge : node->out_edges()) { + sgraph.AddEdge(node_arg, 0, edge->dst(), edge->dst_input()); + VLOG(1) << "Updating funcdef input " << node_arg->name() << ":" << 0 + << " - > " << edge->dst()->name() << ":" << edge->dst_input(); + if (!s.ok()) { + LOG(ERROR) << "Failed to update edge from " << node_arg->name() + << " to " << edge->dst()->name() << ":" << edge->dst_input(); + } + } + sgraph.RemoveNode(node); + } + + for (int i = 0; i < io_nodes.size() - num_inputs; ++i) { + auto name = StrCat(kOutputPHName, i); + auto node = io_nodes[name]; + tensorflow::NodeDef nd; + tensorflow::NodeDefBuilder node_builder( + StrCat(name, "_Ret"), tensorflow::FunctionLibraryDefinition::kRetOp); + auto edge = *(node->in_edges().begin()); + tensorflow::NodeDefBuilder::NodeOut nout( + edge->src()->name(), edge->src_output(), + edge->src()->output_type(edge->src_output())); + VLOG(1) << " input " << nout.node << ":" << nout.index + << " dtype=" << tensorflow::DataTypeString(nout.data_type); + node_builder.Input({nout}); + TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0)) + .Attr("index", i) + .Finalize(&nd)); + if (VLOG_IS_ON(3)) { + VLOG(3) << nd.DebugString(); + } + tensorflow::Status s; + auto node_ret = sgraph.AddNode(nd, &s); + if (!s.ok()) { + LOG(ERROR) << "Couldn't add _Ret node for " << name; + } + VLOG(1) << "Update edge from " << edge->src()->name() << ":" + << edge->src_output() << " - > " << node_ret->name() << ":" << 0; + sgraph.AddEdge(edge->src(), edge->src_output(), node_ret, 0); + s = sgraph.UpdateEdge(edge->src(), edge->src_output(), node_ret, 0); + if (!s.ok()) { + LOG(ERROR) << "Failed to update edge from " << edge->src()->name() << ":" + << edge->src_output() << " - > " << node_ret->name() << ":" + << 0; + } + sgraph.RemoveNode(node); + } + tensorflow::FunctionDefLibrary fdeflib; + auto native_segment = fdeflib.add_function(); + TF_RETURN_IF_ERROR(tensorflow::GraphToFunctionDef( + sgraph, StrCat(name, "_native_segment"), native_segment)); + if (VLOG_IS_ON(7)) { + VLOG(7) << name << " Function_Def "; + VLOG(7) << native_segment->DebugString(); + } + VLOG(1) << "Adding funcdef to graphlib"; + TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdeflib)); + return tensorflow::Status::OK(); +} + +std::pair GetDeviceAndAllocator( + ConversionParams& params, EngineInfo& engine) { + int cuda_device_id = -1; + auto check_device_id = [](int tfid) -> int { + tensorflow::TfGpuId tf_gpu_id(tfid); + CudaGpuId cuda_gpu_id; + Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id); + if (s.ok()) { + VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device " + << cuda_gpu_id.value(); + return cuda_gpu_id.value(); + } + VLOG(2) << "TF GPU with id " << tfid << " do not exist " << s; + return -1; + }; + tensorflow::Allocator* dev_allocator = nullptr; + // we need to us PM here since in python path there is no way to get + // to allocators. + // TODO(sami): when grappler devices become available else path will not be + // necessary + auto pm = tensorflow::ProcessState::singleton(); + if (params.cluster) { // get allocator + tensorflow::Device* device = nullptr; + if (params.cluster->GetDeviceSet()) { + device = params.cluster->GetDeviceSet()->FindDeviceByName(engine.device); + } + if (device) { + tensorflow::AllocatorAttributes alloc_attr; + dev_allocator = device->GetAllocator(alloc_attr); + VLOG(1) << "Using allocator " << dev_allocator->Name(); + } else { + LOG(WARNING) << "Cluster is set but device '" << engine.device + << "' is not found in the cluster"; + } + } else { // cluster not found, possibly a python call + VLOG(1) << "Cluster is not set, probably called from python"; + int found_device = 0; + bool try_gpu_ids = true; + // if device is set, try to find the device. Might be a problem for multi + // host case but TensorRT do not support multi host setups yet. + if (!engine.device.empty()) { + DeviceNameUtils::ParsedName parsed_name; + if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name)) { + cuda_device_id = parsed_name.has_id ? parsed_name.id : -1; + } + try_gpu_ids = !parsed_name.has_id; + } + if (try_gpu_ids) { + while (found_device < 100) { + cuda_device_id = check_device_id(found_device); + if (cuda_device_id >= 0) break; + found_device++; + } + } + if (found_device == 100) { + LOG(ERROR) << " Can't find a GPU device to work with. Please " + "instantiate a session to initialize devices"; + return std::make_pair(cuda_device_id, dev_allocator); + } + LOG(WARNING) + << "Can't determine the device, constructing an allocator at device " + << found_device; + tensorflow::GPUOptions gpuoptions; + // this will be a noop if device is already initialized + gpuoptions.set_allow_growth(true); + tensorflow::TfGpuId tf_gpu_id(found_device); + dev_allocator = pm->GetGPUAllocator(gpuoptions, tf_gpu_id, 1); + } + return std::make_pair(cuda_device_id, dev_allocator); +} + +// Entry function from optimization pass. +tensorflow::Status ConvertAfterShapes(ConversionParams& params) { + // Convert graphdef to graph. tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(), - gdef.library()); + params.input_graph_def->library()); tensorflow::Graph graph(flib); TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph( - tensorflow::GraphConstructorOptions(), gdef, &graph)); + tensorflow::GraphConstructorOptions(), *params.input_graph_def, &graph)); + // Segment the graph into subgraphs that can be converted to TensorRT + tensorflow::tensorrt::segment::SegmentOptions segment_options; // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT) - for (auto node : output_names) { + for (auto node : *(params.output_names)) { segment_options.exclude_node_list.insert(node); } - - // TODO(sami): this should be passed as a knob!!!! - segment_options.minimum_segment_size = minimum_segment_size; - tensorflow::tensorrt::segment::SegmentNodesVector segments; + segment_options.minimum_segment_size = params.minimum_segment_size; + tensorflow::tensorrt::segment::SegmentNodesVector initial_segments; TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph( - &graph, IsTensorRTCandidate, segment_options, &segments)); - if (segments.size() > 1) { - VLOG(0) << "MULTIPLE tensorrt candidate conversion: " << segments.size(); + &graph, IsTensorRTCandidate, segment_options, &initial_segments)); + if (initial_segments.size() > 1) { + VLOG(0) << "MULTIPLE tensorrt candidate conversion: " + << initial_segments.size(); } + + // Get the EngineInfo for each segment. std::unordered_map node_map; TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map)); - std::unordered_map> output_edge_map; - int count = 0; float total_num_nodes_in_segments = 0.; - for (auto s : segments) { - total_num_nodes_in_segments += s.first.size(); - } - // We create the map here since cluster may not be available in all cases. - std::map name_to_device_map; - if (cluster) { - // TODO(aaroey): consider using DeviceSet::FindDeviceByName(), as in a - // distributed environment, devices from different workers can have same - // short name. - for (const auto dm : cluster->GetDeviceSet()->devices()) { - name_to_device_map[dm->name()] = dm; + std::vector engine_segments; + engine_segments.reserve(initial_segments.size()); + std::vector reverse_topo_order; + tensorflow::GetPostOrder(graph, &reverse_topo_order); + size_t total_engine_bytes_size = 0; + std::vector engine_bytes_size; + tensorflow::tensorrt::segment::SegmentNodesVector converted_segments; + converted_segments.reserve(initial_segments.size()); + for (size_t t = 0; t < initial_segments.size(); t++) { + auto& curr_segment = initial_segments.at(t); + EngineInfo curr_engine; + Status status = + GetEngineInfo(&graph, *params.graph_properties, curr_segment.first, + node_map, reverse_topo_order, &curr_engine); + if (!status.ok()) { + LOG(WARNING) << "Failed to get engine info for segment " << t << ": " + << status; + continue; } - } - for (const auto& segment_nodes_and_device : segments) { - const std::set& subgraph_node_names = - segment_nodes_and_device.first; - std::set subgraph_node_ids; - size_t max_mem_per_engine = - max_workspace_size_bytes * - ((float)subgraph_node_names.size() / total_num_nodes_in_segments); - std::stringstream oss; - for (const string& node_name : subgraph_node_names) { - oss << " " << node_name; - subgraph_node_ids.insert(node_map.at(node_name)->id()); + curr_engine.precision_mode = params.precision_mode; + curr_engine.engine_type = + (params.is_dyn_op || params.precision_mode == INT8MODE + ? EngineInfo::EngineType::TRTDynamic + : EngineInfo::EngineType::TRTStatic); + curr_engine.cached_engine_batches = params.cached_engine_batches; + curr_engine.maximum_cached_engines = params.max_cached_engines; + StrAppend(&curr_engine.engine_name, "my_trt_op_", t); + status = RegisterSegmentFunctionToFunctionLibrary( + &graph, curr_engine.segment_graph_def, curr_engine.engine_name); + if (!status.ok()) { + LOG(WARNING) << "Failed to register segment graphdef as a function " << t + << ": " << status; + continue; } - VLOG(1) << "Subgraph nodes at device " << segment_nodes_and_device.second - << " : " << oss.str(); - auto target_device = - name_to_device_map.find(segment_nodes_and_device.second); - std::shared_ptr allocator(0); + engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong()); + total_engine_bytes_size += engine_bytes_size.back(); + total_num_nodes_in_segments += curr_segment.first.size(); + engine_segments.push_back(std::move(curr_engine)); + converted_segments.push_back(std::move(curr_segment)); + + if (VLOG_IS_ON(8)) { + string fname = curr_engine.engine_name; + StrAppend(&fname, ".pb"); + std::fstream f; + f.open(fname.c_str(), std::fstream::out | std::fstream::binary); + f << engine_segments.at(t).segment_graph_def.SerializeAsString(); + f.close(); + } + } + + // Create a TRT node for each segment using its EngineInfo. + int old_cuda_device = 0; + auto err = cudaGetDevice(&old_cuda_device); + if (err != cudaSuccess) { + LOG(ERROR) << "Couldn't get current device: " << cudaGetErrorString(err); + } + VLOG(1) << "Current cuda device is " << old_cuda_device; + for (int i = 0; i < engine_segments.size(); ++i) { + auto& engine = engine_segments.at(i); + // Partition the workspace size by the average of node ratio and segment + // graphdef size + engine.max_workspace_size_bytes = + params.max_workspace_size_bytes * + (engine_bytes_size.at(i) / total_engine_bytes_size + + converted_segments.at(i).first.size() / total_num_nodes_in_segments) / + 2.0; + // The allocator is used to build the engine. The build and the built engine + // will be destroyed after we get the serialized engine string, so it's fine + // to use unique_ptr here. + std::unique_ptr alloc; + auto device_alloc = GetDeviceAndAllocator(params, engine); int cuda_device_id = 0; - if (target_device != name_to_device_map.end()) { - tensorflow::TfGpuId tf_gpu_id(target_device->second->parsed_name().id); - CudaGpuId cuda_gpu_id; - Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id); - if (!s.ok()) { - LOG(ERROR) - << "Cuda device identification failed, using device 0. Error= " - << s; - } else { - cuda_device_id = cuda_gpu_id.value(); - } - tensorflow::GPUOptions gpuoptions; - // we need to us PM here since in python path there is no way to get to - // allocators - auto pm = tensorflow::ProcessState::singleton(); - // this should be instantiated by now - auto dev_allocator = pm->GetGPUAllocator(gpuoptions, tf_gpu_id, 1); - VLOG(1) << "Got an allocator for device tf_device=" << tf_gpu_id.value() - << " cuda device= " << cuda_device_id << " at " << dev_allocator; - allocator = std::make_shared(dev_allocator); - } else { // device unknown or not available - allocator = std::make_shared(); + if (device_alloc.first >= 0) { + cuda_device_id = device_alloc.first; + alloc.reset(new TRTDeviceAllocator(device_alloc.second)); + } else { + // Setting allocator as nullptr should get revert to the cudamalloc + LOG(WARNING) << "Can't identify the cuda device. Running on device 0 "; } - ConvertGraphParams p(graph, output_names, subgraph_node_ids, max_batch_size, - max_mem_per_engine, graph_properties, &output_edge_map, - precision_mode, segment_nodes_and_device.second, - allocator, cuda_device_id); - if (precision_mode == INT8MODE) { - tensorflow::Status status = GetCalibNode(&p); - if (status != tensorflow::Status::OK()) { - LOG(WARNING) << "subgraph conversion error for subgraph_index:" << count - << " due to: \"" << status.ToString() - << "\" SKIPPING......( " << subgraph_node_names.size() - << " nodes)"; + cudaSetDevice(cuda_device_id); + auto status = CreateTRTNode(&graph, engine_segments, i, alloc.get(), + params.max_batch_size); + // If status is ok, we successfully added the node to the graph and can + // remove segment ops. Otherwise graph is not modified. + if (status.ok()) { + for (auto node_name : converted_segments.at(i).first) { + graph.RemoveNode(node_map.at(node_name)); } } else { - tensorflow::Status status = ConvertSubGraphToTensorRT(&p); - if (status != tensorflow::Status::OK()) { - LOG(WARNING) << "subgraph conversion error for subgraph_index:" << count - << " due to: \"" << status.ToString() - << "\" SKIPPING......( " << subgraph_node_names.size() - << " nodes)"; - } + // Graph is not modified. + LOG(WARNING) << "Engine creation for segment " << i << ", composed of " + << converted_segments.at(i).first.size() + << " nodes failed: " << status << ". Skipping..."; } - count++; } - graph.ToGraphDef(new_graph_def); + cudaSetDevice(old_cuda_device); + graph.ToGraphDef(params.output_graph_def); + VLOG(1) << "Returning from conversion"; return tensorflow::Status::OK(); } diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h index 65a67d7e73e32f904bd636a4f4aaefe32b0c092d..9d986e489043c0a0e16e379166aa2e8f7ac0b11f 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.h +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h @@ -30,29 +30,60 @@ namespace tensorflow { namespace tensorrt { namespace convert { -// This method converts an already generated calibration graph which was used in -// calibration runs to an inference graph +struct ConversionParams { + ConversionParams() + : input_graph_def(nullptr), + max_batch_size(1), + max_workspace_size_bytes(1 << 30), + output_graph_def(nullptr), + precision_mode(1), + minimum_segment_size(3), + graph_properties(nullptr), + cluster(nullptr), + is_dyn_op(false), + fixed_input_size(true), + max_cached_engines(1) {} + const tensorflow::GraphDef* input_graph_def; + const std::vector* output_names; + size_t max_batch_size; + size_t max_workspace_size_bytes; + tensorflow::GraphDef* output_graph_def; + int precision_mode; + int minimum_segment_size; + const tensorflow::grappler::GraphProperties* graph_properties; + const tensorflow::grappler::Cluster* cluster; + bool is_dyn_op; // Whether to create engine on conversion or execution time + bool fixed_input_size; // Assume non-batch ranks of input tensors are fixed + int max_cached_engines; // maximum number of cached engines + std::vector cached_engine_batches; // list of cached engines +}; + +// This method extracts calibration information from the resource managers +// and puts them in to engine nodedefs. tensorflow::Status ConvertCalibGraphToInferGraph( - const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def); + const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def, + bool is_dyn_op); -// max_batch_size: maximum batch size which can be used for inference for -// optimization targets inference run with max batch size. -// max_workspace_size_bytes: The upper bound of memory allowance for -// engine building. +// - max_batch_size: maximum batch size which can be used for inference for +// optimization targets inference run with max batch size. +// - max_workspace_size_bytes: The upper bound of memory allowance for engine +// building. tensorflow::Status ConvertGraphDefToTensorRT( const tensorflow::GraphDef& graph_def, const std::vector& output_names, size_t max_batch_size, size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, - int precision_mode, int minimum_segment_size); + int precision_mode = 1, int minimum_segment_size = 3, + bool is_dyn_op = false, int max_cached_engines = 1, + std::vector cached_engine_batches = {}); // Method to call from optimization pass -tensorflow::Status ConvertAfterShapes( - const tensorflow::GraphDef& graph, const std::vector& output_names, - size_t max_batch_size, size_t max_workspace_size_bytes, - tensorflow::GraphDef* new_graph_def, int precision_mode, - int minimum_segment_size, - const tensorflow::grappler::GraphProperties& graph_properties, - const tensorflow::grappler::Cluster* cluster); +tensorflow::Status ConvertAfterShapes(ConversionParams& params); + +// Return compile time TensorRT library version information. +std::vector GetLinkedTensorRTVersion(); + +// Return runtime time TensorRT library version information. +std::vector GetLoadedTensorRTVersion(); } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 4e4d295538edadd26a347a38ec141737f097f26f..146b9c7344b0a9c2b3ec87b395e9b1096dbef06c 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include #include @@ -25,7 +24,9 @@ limitations under the License. #include #include +#include "tensorflow/contrib/tensorrt/convert/utils.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" #include "tensorflow/contrib/tensorrt/resources/trt_resources.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT @@ -37,6 +38,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -54,8 +56,11 @@ limitations under the License. namespace tensorflow { namespace tensorrt { namespace convert { +using ::tensorflow::str_util::Split; + using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; + namespace { inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype, @@ -121,12 +126,10 @@ static std::vector> CreateSamePadding( string GetCommonNameScope(const string& op_name_a, const string& op_name_b) { size_t last_scope_separator = 0; - for (size_t i = 0; i < std::min(op_name_a.size(), op_name_b.size()); ++i) { - if (op_name_a[i] != op_name_b[i]) { - break; - } else if (op_name_a[i] == '/') { - last_scope_separator = i + 1; - } + const size_t min_size = std::min(op_name_a.size(), op_name_b.size()); + for (size_t i = 0; i < min_size; ++i) { + if (op_name_a[i] != op_name_b[i]) break; + if (op_name_a[i] == '/') last_scope_separator = i + 1; } return op_name_a.substr(0, last_scope_separator); } @@ -417,20 +420,6 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, } } -struct InferDeleter { - template - void operator()(T* obj) const { - if (obj) { - obj->destroy(); - } - } -}; - -template -inline std::shared_ptr infer_object(T* obj) { - return std::shared_ptr(obj, InferDeleter()); -} - class Converter; using OpConverter = @@ -444,7 +433,7 @@ class Converter { OpConverter plugin_converter_; nvinfer1::INetworkDefinition* trt_network_; std::list> temp_bufs_; - tensorflow::tensorrt::TRTWeightStore* weight_store_; + TRTWeightStore* weight_store_; bool fp16_; void register_op_converters(); tensorflow::Status get_inputs(const tensorflow::NodeDef& node_def, @@ -486,11 +475,11 @@ class Converter { public: explicit Converter(nvinfer1::INetworkDefinition* trt_network, - tensorflow::tensorrt::TRTWeightStore* ws, bool fp16) + TRTWeightStore* ws, bool fp16) : trt_network_(trt_network), weight_store_(ws), fp16_(fp16) { this->register_op_converters(); } - tensorflow::tensorrt::TRTWeightStore* weight_store() { return weight_store_; } + TRTWeightStore* weight_store() { return weight_store_; } TRT_ShapedWeights get_temp_weights(tensorflow::DataType type, nvinfer1::Dims shape) { TRT_ShapedWeights weights(type, nullptr, shape); @@ -2140,559 +2129,265 @@ void Converter::register_op_converters() { } // namespace -tensorflow::Status ConvertCalibrationNodeToEngineNode( - tensorflow::Graph& graph, tensorflow::Node* c_node) { - const auto ndef = c_node->def(); - - TFAttrs attrs(ndef); - std::vector segment_nodes( - attrs.get>("segment_nodes")); - std::vector output_nodes( - attrs.get>("segment_output_names")); - std::vector input_names( - attrs.get>("input_names")); - string res_name = attrs.get("resource_name"); - VLOG(1) << "Node name " << c_node->name() << " res_name " << res_name; - string engine_name = "my_trt_op"; - { - const auto node_id = tensorflow::str_util::Split(res_name, "_"); - engine_name += node_id.back(); - } - std::map node_maps; - - for (auto n : graph.op_nodes()) { - node_maps.insert({n->name(), n}); - } - std::set subgraph_ids; - for (const auto internal_node : segment_nodes) { - subgraph_ids.insert(node_maps.at(internal_node)->id()); - } - if (VLOG_IS_ON(2)) { - string node_names = StrCat(c_node->name(), " segment nodes= "); - - for (const auto& node_name : segment_nodes) { - StrAppend(&node_names, node_name, ", "); - } - VLOG(2) << node_names; +tensorflow::Status ConvertGraphDefToEngine( + const tensorflow::GraphDef& gdef, int precision_mode, int max_batch_size, + size_t max_workspace_size_bytes, + const std::vector& input_shapes, + Logger* logger, nvinfer1::IGpuAllocator* allocator, + TRTInt8Calibrator* calibrator, + TrtUniquePtrType* engine, + bool* convert_successfully) { + engine->reset(); + if (convert_successfully) *convert_successfully = false; + + // Create the builder. + TrtUniquePtrType builder( + nvinfer1::createInferBuilder(*logger)); + builder->setMaxBatchSize(max_batch_size); + // TODO(aaroey): use the allocator to allocate the TRT workspace. + builder->setMaxWorkspaceSize(max_workspace_size_bytes); +#if NV_TENSORRT_MAJOR > 3 + builder->setGpuAllocator(allocator); +#endif + if (precision_mode == FP16MODE) { + builder->setHalf2Mode(true); + } else if (precision_mode == INT8MODE) { + builder->setInt8Mode(true); + builder->setInt8Calibrator(calibrator); } - VLOG(1) << "Output Nodes:"; - std::vector out_types; - std::vector out_edges; + // Create the network. + auto trt_network = + TrtUniquePtrType(builder->createNetwork()); + if (!trt_network) { + return tensorflow::errors::Internal( + "Failed to create TensorRT network object"); + } + auto ws = std::unique_ptr(new TRTWeightStore()); - for (auto& i : output_nodes) { - auto node_port = tensorflow::str_util::Split(i, ":"); - VLOG(1) << " " << i << " in graph " << node_maps.count(i); - auto out_node_name = node_port.at(0); - if (node_port.size() > 1) { - VLOG(1) << "Multi port output" << node_port.at(0) << " " - << node_port.at(1) << " size=" << node_port.size(); - } - auto node_it = node_maps.find(out_node_name); - if (node_it != node_maps.end()) { - tensorflow::Node* out_node = node_it->second; - int port = 0; - if (node_port.size() == 2) { - port = std::strtoul(node_port.at(1).c_str(), nullptr, 10); - out_types.push_back(out_node->output_type(port)); - } else { - out_types.push_back(out_node->output_type(0)); + // Build the network + VLOG(1) << "Starting engine conversion "; + Converter converter(trt_network.get(), ws.get(), precision_mode == FP16MODE); + std::vector> output_tensors; + // Graph nodes are already topologically sorted during construction + for (const auto& node_def : gdef.node()) { + string node_name = node_def.name(); + VLOG(1) << "Converting op name=" << node_name << ", op=" << node_def.op(); + if (tensorflow::str_util::StartsWith(node_name, kInputPHName) && + (node_def.op() == "Placeholder")) { + nvinfer1::DimsCHW input_dim_pseudo_chw; + for (int i = 0; i < 8; i++) input_dim_pseudo_chw.d[i] = 0; + nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT); + auto type_status = + ConvertDType(node_def.attr().at("dtype").type(), &dtype); + if (type_status != tensorflow::Status::OK()) { + LOG(WARNING) << "Type conversion failed for " << node_name; + return type_status; } - for (auto out_edge : out_node->out_edges()) { - if (subgraph_ids.count(out_edge->dst()->id())) - continue; // skip internal edges; - if (out_edge->src_output() == port) { - out_edges.push_back(out_edge); - VLOG(1) << "OUTPUT EDGE " << out_edge->src()->name() << ":" - << out_edge->src_output() << " -> " << out_edge->dst()->name() - << ":" << out_edge->dst_input(); + int32 slot_number = -1; + if (!tensorflow::strings::safe_strto32(node_name.c_str() + 8, + &slot_number)) { + LOG(ERROR) << "Failed to parse slot number from " << node_name + << " +8= " << node_name.c_str() + 8; + } + auto shape = input_shapes.at(slot_number); + if (shape.dims() > 8) { + LOG(ERROR) << "Tensor rank is greater than 8 for " << node_name + << " at input slot " << slot_number; + return tensorflow::errors::OutOfRange( + "Input tensor rank is greater than 8"); + } + if (VLOG_IS_ON(1)) { + string dim_str("dims="); + StrAppend(&dim_str, "[ ", shape.dim_size(0)); + for (int i = 1; i < shape.dims(); i++) { + StrAppend(&dim_str, ", ", shape.dim_size(i)); } + StrAppend(&dim_str, " ]"); + VLOG(1) << dim_str; + } + for (int i = 1; i < shape.dims(); i++) { + input_dim_pseudo_chw.d[i - 1] = shape.dim_size(i); } - } else { - LOG(WARNING) << " couldn't find output node " << out_node_name; - } - } - if (VLOG_IS_ON(1)) { - VLOG(1) << c_node->name() << " Input Nodes:"; - for (auto& i : input_names) { - VLOG(1) << " Input " << i << " in graph " << node_maps.count(i); - } - } - auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance(); - auto resmgr = trt_rm->getManager("TRTCalibOps"); - tensorflow::tensorrt::TRTCalibrationResource* calib_res = nullptr; - auto status = resmgr->Lookup(res_name, res_name, &calib_res); - if (!status.ok() || !calib_res->calibrator_) { - return tensorflow::errors::FailedPrecondition( - "You must run calibration" - " and inference conversion in the same process"); - } - - calib_res->calibrator_->setDone(); - calib_res->thr_->join(); - delete calib_res->thr_; - if (!calib_res->engine_) { - LOG(ERROR) << "Calibration failed!, engine does not exist. Did you run " - "calibration graph?"; - return tensorflow::errors::FailedPrecondition( - "Calibration graph needs to be executed on" - " calibration data before convertsion to inference graph"); - } - auto weight_rmgr = trt_rm->getManager("WeightStore"); - TF_CHECK_OK(weight_rmgr->Delete( - res_name, res_name)); - auto engine_plan = calib_res->engine_->serialize(); - calib_res->engine_->destroy(); - calib_res->network_->destroy(); - calib_res->builder_->destroy(); - calib_res->thr_ = nullptr; - calib_res->engine_ = nullptr; - calib_res->builder_ = nullptr; - tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp"); - std::vector income_edges; - income_edges.resize(c_node->num_inputs()); - for (const auto in_edge : c_node->in_edges()) { - auto src = in_edge->src(); - int dest_port = in_edge->dst_input(); - VLOG(1) << "Incoming connection " << src->name() << ":" - << in_edge->src_output() << " -> " << c_node->name() << ":" - << dest_port; - income_edges.at(dest_port) = {src->name(), in_edge->src_output(), - c_node->input_type(dest_port)}; - } - tensorflow::gtl::ArraySlice input_list( - income_edges); - if (VLOG_IS_ON(2)) { - for (const auto& inp : input_list) { - VLOG(2) << " Input from inputlist " << inp.node << ":" << inp.index << " " - << tensorflow::DataTypeString(inp.data_type); - } - } - op_builder.Input(input_list); - tensorflow::NodeDef engine_node; - const char* engine_plan_data = static_cast(engine_plan->data()); - string engine_plan_string(engine_plan_data, - engine_plan_data + engine_plan->size()); - status = op_builder.Attr("serialized_engine", engine_plan_string) - .Attr("input_nodes", input_names) - .Attr("output_nodes", output_nodes) - .Attr("OutT", out_types) - .Finalize(&engine_node); - if (!status.ok()) { - LOG(ERROR) << "Engine Node creation failed"; - return status; - } - auto trt_engine_node = graph.AddNode(engine_node, &status); - TF_RETURN_IF_ERROR(status); - std::map port_map; - for (size_t t = 0; t < output_nodes.size(); t++) { - port_map.insert({output_nodes.at(t), t}); - } - for (auto& i : out_edges) { - string s(i->src()->name()); - if (i->src_output()) StrAppend(&s, ":", i->src_output()); - int out_port = port_map.at(s); - VLOG(1) << "Connecting " << trt_engine_node->name() << ":" << out_port - << " -> " << i->dst()->name() << ":" << i->dst_input(); - TF_RETURN_IF_ERROR( - graph.UpdateEdge(trt_engine_node, out_port, i->dst(), i->dst_input())); - } - for (const auto ed : trt_engine_node->in_edges()) { - VLOG(1) << "In Edge " << ed->src()->name() << ":" << ed->src_output() - << " -> " << ed->dst()->name() << ":" << ed->dst_input(); - } - for (const auto ed : trt_engine_node->out_edges()) { - VLOG(1) << "Out Edge " << ed->src()->name() << ":" << ed->src_output() - << " -> " << ed->dst()->name() << ":" << ed->dst_input(); - } - VLOG(1) << "Segment nodes:"; - for (auto& i : segment_nodes) { - VLOG(1) << " " << i << " in graph " << node_maps.count(i); - auto it = node_maps.find(i); - if (it != node_maps.end()) { - graph.RemoveNode(it->second); - } - } - graph.RemoveNode(c_node); - return tensorflow::Status::OK(); -} -tensorflow::Status ReverseTopologicalSort( - const tensorrt::convert::SubGraphParams& s, - std::list* order) { - std::vector order_vec; - tensorflow::GetPostOrder(s.graph, &order_vec); - // Select just the subgraph - for (tensorflow::Node* node : order_vec) { - if (s.subgraph_node_ids.count(node->id())) { - // We want topological order to contstruct the - // network layer by layer - order->push_front(node); + input_dim_pseudo_chw.nbDims = shape.dims() - 1; + nvinfer1::ITensor* input_tensor = converter.network()->addInput( + node_name.c_str(), dtype, input_dim_pseudo_chw); + if (!input_tensor) { + return tensorflow::errors::InvalidArgument( + "Failed to create Input layer tensor ", node_name, + " rank=", shape.dims() - 1); + } + VLOG(1) << "Input tensor name :" << node_name; + if (!converter.insert_input_tensor(node_name, input_tensor)) { + return tensorflow::errors::AlreadyExists( + "Output tensor already exists for op: " + node_name); + } + } else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) && + (node_def.op() == "Identity")) { + int32 slot_number = -1; + if (!tensorflow::strings::safe_strto32(node_name.c_str() + 9, + &slot_number)) { + LOG(ERROR) << "Failed to parse slot number from " << node_name + << " +9=" << node_name.c_str() + 9; + } + if (output_tensors.size() <= slot_number) { + output_tensors.resize(slot_number + 1); + } + output_tensors.at(slot_number) = {node_def.input(0), node_name}; + } else { + VLOG(2) << "Converting node: " << node_def.name() << " , " + << node_def.op(); + TF_RETURN_IF_ERROR(converter.convert_node(node_def)); } } - return tensorflow::Status::OK(); -} - -tensorflow::Status SetInputList( - const tensorrt::convert::SubGraphParams& s, - tensorflow::NodeDefBuilder* op_builder, - const std::vector* input_names, - std::vector* input_dtypes) { - std::vector income_edges; - VLOG(2) << "input edge size: " << input_names->size(); - for (size_t i = 0; i < input_names->size(); ++i) { - VLOG(2) << "input edges: " << i << " " << input_names->at(i); - int output_idx = s.input_inds.at(i).second; - // we wired up the input here already, it is redundant to do it again in - // ConvertSubGraphToTensorRT(convert_graph.cc) - auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut( - input_names->at(i), output_idx, input_dtypes->at(i)); - income_edges.push_back(incoming_edge); - } - tensorflow::gtl::ArraySlice input_list( - income_edges); - op_builder->Input(input_list); - return tensorflow::Status::OK(); -} - -string SubgraphNameScopeGenerator(const std::list* order) { - string subgraph_name_scope; - if (!order->empty()) { - subgraph_name_scope = order->front()->name(); - } - for (const tensorflow::Node* node : *order) { - subgraph_name_scope = GetCommonNameScope(subgraph_name_scope, node->name()); - } - // TODO(sami,ben,jie): proper naming! - return subgraph_name_scope; -} - -tensorflow::Status ConvertSubgraph( - Converter& converter, tensorrt::convert::SubGraphParams& s, - std::list* order, std::vector* input_names, - std::vector* input_dtypes, - std::vector* output_names, - std::vector* output_dtypes, - const string& engine_name) { - std::set added_tensors; - for (const std::pair& input : s.input_inds) { - VLOG(2) << "parsing input. Node id= " << input.first; - int node_id = input.first; - int output_idx = input.second; - tensorflow::Node* node = s.graph.FindNodeId(node_id); - auto node_name = node->name(); - // input_names should use the node name in the graph - // here it should be the input tensor name -> matching the binding - // insert original node name without port - auto tensor_name = node_name; - if (output_idx != 0) { - tensor_name = StrCat(tensor_name, ":", output_idx); - } - - VLOG(2) << "input name: " << node_name << " tensor_name: " << tensor_name - << " idx: " << output_idx; - - auto shape_inference_node_name = node_name; - auto shape_inference_output_idx = output_idx; - // rewire the shape inference to original node in the graph - if (s.output_edge_map->count(tensor_name)) { - shape_inference_node_name = s.output_edge_map->at(tensor_name).second; - shape_inference_output_idx = s.output_edge_map->at(tensor_name).first; - } - if (shape_inference_output_idx < 0) continue; - VLOG(2) << "shapeinference name: " << shape_inference_node_name - << " idx: " << shape_inference_output_idx; - - if (!s.graph_properties.HasOutputProperties(shape_inference_node_name)) - return tensorflow::errors::Internal("failed to find input node: " + - shape_inference_node_name); - - auto op_info_vec = - s.graph_properties.GetOutputProperties(shape_inference_node_name); - if (static_cast(op_info_vec.size()) <= shape_inference_output_idx) - return tensorflow::errors::Internal( - "accessing output index of: ", shape_inference_output_idx, - ", at node: ", shape_inference_node_name, - " with output entry from shape_map: ", op_info_vec.size()); - - auto op_info = op_info_vec.at(shape_inference_output_idx); - tensorflow::DataType tf_dtype = op_info.dtype(); - - nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT); - auto type_status = ConvertDType(tf_dtype, &dtype); - if (type_status != tensorflow::Status::OK()) { - LOG(WARNING) << "Type conversion failed for " << node_name; - return type_status; - } - - VLOG(2) << "Accessing output index of: " << output_idx - << ", at node: " << node_name - << " with output entry from shape_map: " << op_info_vec.size(); - // TODO(ben,jie): update TRT input format/dimension - nvinfer1::DimsCHW input_dim_pseudo_chw; - for (int i = 0; i < 3; i++) input_dim_pseudo_chw.d[i] = 1; - - // TODO(jie): TRT 3.x only support 4 dimensional input tensor. - // update the code once TRT 4.0 comes out. - if (op_info.shape().dim_size() != 4) { - string err_str = "Require 4 dimensional input."; - StrAppend(&err_str, " Got ", op_info.shape().dim_size(), " ", - shape_inference_node_name); - return tensorflow::errors::Unimplemented(err_str); - } - - for (int i = 1; i < op_info.shape().dim_size(); i++) { - VLOG(2) << "dimension: " << i - << " , size: " << op_info.shape().dim(i).size(); - input_dim_pseudo_chw.d[i - 1] = op_info.shape().dim(i).size(); - } - - // TODO(ben,jie): proper way to restore input tensor name? - auto input_tensor_name = node_name; - if (output_idx != 0) { - input_tensor_name = StrCat(node_name, ":", output_idx); - } - if (added_tensors.count(input_tensor_name)) continue; - added_tensors.insert(input_tensor_name); - input_names->push_back(input_tensor_name); - input_dtypes->push_back(tf_dtype); - nvinfer1::ITensor* input_tensor = converter.network()->addInput( - input_tensor_name.c_str(), dtype, input_dim_pseudo_chw); - - if (!input_tensor) - return tensorflow::errors::InvalidArgument( - "Failed to create Input layer"); - VLOG(2) << "Input tensor name :" << input_tensor_name; - - if (!converter.insert_input_tensor(input_tensor_name, input_tensor)) - return tensorflow::errors::AlreadyExists( - "Output tensor already exists for op: " + input_tensor_name); - } - - for (const tensorflow::Node* node : *order) { - const tensorflow::NodeDef& node_def = node->def(); - VLOG(2) << "Converting node: " << node_def.name() << " , " << node_def.op(); - TF_RETURN_IF_ERROR(converter.convert_node(node_def)); - } - - VLOG(2) << "Finished conversion"; - - // Gather output metadata - int trt_engine_op_output_idx = 0; - added_tensors.clear(); - for (const std::pair& output : s.output_inds) { - int node_id = output.first; - int output_idx = output.second; - tensorflow::Node* node = s.graph.FindNodeId(node_id); - string op_name = node->name(); - string tensor_name = op_name; - - s.output_edge_map->insert( - {trt_engine_op_output_idx == 0 - ? engine_name - : StrCat(engine_name, ":", trt_engine_op_output_idx), - {output_idx, tensor_name}}); - trt_engine_op_output_idx++; - if (output_idx != 0) - tensorflow::strings::StrAppend(&tensor_name, ":", output_idx); - VLOG(2) << "Output tensor name: " << tensor_name; - if (added_tensors.count(tensor_name)) continue; - added_tensors.insert(tensor_name); - output_names->push_back(tensor_name); - auto tensor_or_weights = converter.get_tensor(tensor_name); + for (const auto& output : output_tensors) { + auto tensor_or_weights = converter.get_tensor(output.first); if (!tensor_or_weights.is_tensor()) { - return tensorflow::errors::InvalidArgument("Output node '" + tensor_name + - "' is weights not tensor"); + return tensorflow::errors::InvalidArgument( + "Output node '" + output.first + "' is weights not tensor"); } nvinfer1::ITensor* tensor = tensor_or_weights.tensor(); + tensor->setName(output.second.c_str()); if (!tensor) { return tensorflow::errors::NotFound("Output tensor not found: " + - tensor_name); + output.first); } + VLOG(1) << "Marking output tensor " << output.first << ", as output tensor " + << output.second; + converter.network()->markOutput(*tensor); - tensorflow::DataType tf_dtype = node->output_type(output_idx); - output_dtypes->push_back(tf_dtype); - nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT; - TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype)); - tensor->setType(trt_dtype); } + if (convert_successfully) *convert_successfully = true; - return tensorflow::Status::OK(); -} - -tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) { - // Visit nodes in reverse topological order and construct the TRT network. - // Toposort - std::list order; - TF_RETURN_IF_ERROR(ReverseTopologicalSort(s, &order)); - - static int static_id = 0; - string subgraph_name_scope = SubgraphNameScopeGenerator(&order); - // TODO(sami,ben,jie): proper naming! - string calib_op_name = - StrCat(subgraph_name_scope, "my_trt_calib_op_", static_id); - string engine_name = StrCat(subgraph_name_scope, "my_trt_op", static_id); - static_id++; - - auto trt_rmgr = tensorflow::tensorrt::TRTResourceManager::instance(); - auto op_rmgr = trt_rmgr->getManager("TRTCalibOps"); - auto op_res = new tensorflow::tensorrt::TRTCalibrationResource(); - TF_CHECK_OK(op_rmgr->Create(calib_op_name, calib_op_name, op_res)); - op_res->logger_ = new tensorflow::tensorrt::Logger(); - cudaSetDevice(s.cuda_gpu_id_); - op_res->builder_ = nvinfer1::createInferBuilder(*(op_res->logger_)); - op_res->allocator_ = s.allocator_; -#if NV_TENSORRT_MAJOR > 3 - op_res->builder_->setGpuAllocator(s.allocator_.get()); -#endif - if (!op_res->builder_) { - return tensorflow::errors::Internal( - "failed to create TensorRT builder object"); + // Build the engine. + VLOG(1) << "Starting engine creation"; + engine->reset(builder->buildCudaEngine(*converter.network())); + if (engine->get() == nullptr) { + return tensorflow::errors::Internal("Failed to build TensorRT engine"); } - - op_res->network_ = op_res->builder_->createNetwork(); - if (!op_res->network_) { - return tensorflow::errors::Internal( - "failed to create TensorRT network object"); - } - - // Build the network - auto weight_rmgr = trt_rmgr->getManager("WeightStore"); - auto ws = new tensorflow::tensorrt::TRTWeightStore(); - TF_CHECK_OK(weight_rmgr->Create(calib_op_name, calib_op_name, ws)); - Converter converter(op_res->network_, ws, s.precision_mode == FP16MODE); - - std::vector input_names; - std::vector input_dtypes; - std::vector output_names; - std::vector output_dtypes; - TF_RETURN_IF_ERROR(ConvertSubgraph(converter, s, &order, &input_names, - &input_dtypes, &output_names, - &output_dtypes, engine_name)); - - VLOG(2) << "Finished processing outputs"; - - // Build the engine - op_res->builder_->setMaxBatchSize(s.max_batch_size); - op_res->builder_->setMaxWorkspaceSize(s.max_workspace_size_bytes); - VLOG(0) << "Max batch size= " << s.max_batch_size - << " max workspace size= " << s.max_workspace_size_bytes; - - // Build the TRT op - // TODO(sami,ben,jie): proper naming! - tensorflow::NodeDefBuilder op_builder(calib_op_name, "TRTCalibOp"); - TF_RETURN_IF_ERROR(SetInputList(s, &op_builder, &input_names, &input_dtypes)); - - std::vector segment_names; - segment_names.reserve(s.subgraph_node_ids.size()); - for (int i : s.subgraph_node_ids) { - auto node = s.graph.FindNodeId(i); - segment_names.push_back(node->name()); - } - LOG(INFO) << "finished op preparation"; - - auto status = op_builder.Attr("segment_nodes", segment_names) - .Attr("input_names", input_names) - .Attr("segment_output_names", output_names) - .Attr("resource_name", calib_op_name) - .Finalize(s.trt_node); - - LOG(INFO) << status.ToString(); - LOG(INFO) << "finished op building"; - + VLOG(1) << "Finished conversion"; return tensorflow::Status::OK(); } -tensorflow::Status ConvertSubGraphToTensorRTNodeDef( - tensorrt::convert::SubGraphParams& s) { - // Visit nodes in reverse topological order and construct the TRT network. - std::list order; - TF_RETURN_IF_ERROR(ReverseTopologicalSort(s, &order)); - - static int static_id = 0; - string subgraph_name_scope = SubgraphNameScopeGenerator(&order); - string engine_name = StrCat(subgraph_name_scope, "my_trt_op", static_id++); - - tensorflow::tensorrt::Logger trt_logger; - cudaSetDevice(s.cuda_gpu_id_); - auto trt_builder = infer_object(nvinfer1::createInferBuilder(trt_logger)); - if (!trt_builder) { - return tensorflow::errors::Internal( - "Failed to create TensorRT builder object"); - } -#if NV_TENSORRT_MAJOR > 3 - trt_builder->setGpuAllocator(s.allocator_.get()); -#endif - auto trt_network = infer_object(trt_builder->createNetwork()); - if (!trt_network) { - return tensorflow::errors::Internal( - "Failed to create TensorRT network object"); - } - - auto trt_rmgr = tensorflow::tensorrt::TRTResourceManager::instance(); - auto weight_rmgr = trt_rmgr->getManager("WeightStore"); - auto ws = new tensorflow::tensorrt::TRTWeightStore(); - TF_CHECK_OK(weight_rmgr->Create(engine_name, engine_name, ws)); - - // Build the network - Converter converter(trt_network.get(), ws, s.precision_mode == FP16MODE); - - std::vector input_names; - std::vector input_dtypes; - std::vector output_names; - std::vector output_dtypes; - TF_RETURN_IF_ERROR(ConvertSubgraph(converter, s, &order, &input_names, - &input_dtypes, &output_names, - &output_dtypes, engine_name)); - - VLOG(2) << "Finished output"; - - // Build the engine - trt_builder->setMaxBatchSize(s.max_batch_size); - trt_builder->setMaxWorkspaceSize(s.max_workspace_size_bytes); - VLOG(0) << "Max batch size= " << s.max_batch_size - << " max workspace size= " << s.max_workspace_size_bytes; - if (s.precision_mode == FP16MODE) { - trt_builder->setHalf2Mode(true); - VLOG(0) << "Using FP16 precision mode"; - } - LOG(INFO) << "starting build engine"; - string engine_plan_string; - { - auto trt_engine = - infer_object(trt_builder->buildCudaEngine(*converter.network())); - VLOG(0) << "Built network"; - if (trt_engine.get() == nullptr) { - return tensorflow::errors::Internal("Engine building failure"); +tensorflow::Status ConvertSegmentToGraphDef( + const tensorflow::Graph* graph, + const tensorflow::grappler::GraphProperties& graph_properties, + const std::vector& subgraph_node_ids, // In topological order + std::vector* connections, + tensorflow::GraphDef* segment_def, string* common_scope) { + std::set marker_nodes; + // Update connection shapes/data types and add corresponding input/output + // nodes in the segment graphdef. + for (size_t i = 0; i < connections->size(); ++i) { + auto& connection = connections->at(i); + auto outside_node = graph->FindNodeId(connection.outside_id); + if (!outside_node) { + // This should never happen, unless the original graph is problematic. + return tensorflow::errors::NotFound( + "Cannot find node with id ", connection.outside_id, " in the graph."); + } + // Updates the shape and data types of input/output connections. + tensorflow::DataType input_type = tensorflow::DT_FLOAT; + tensorflow::PartialTensorShape partial_shape; + if (connection.is_input_edge) { + if (graph_properties.HasOutputProperties(connection.outside_node_name)) { + auto output_params = + graph_properties.GetOutputProperties(connection.outside_node_name); + auto out_shape = output_params.at(connection.outside_port); + input_type = out_shape.dtype(); + std::vector dims; + partial_shape = out_shape.shape(); + connection.outside_shape = partial_shape; + } else { + VLOG(0) << "Unknown output shape" << outside_node->name(); + input_type = graph->FindNodeId(connection.outside_id) + ->output_type(connection.outside_port); + } + connection.connection_type = input_type; + + } else { // output edge + if (graph_properties.HasInputProperties(connection.outside_node_name)) { + auto input_params = + graph_properties.GetInputProperties(connection.outside_node_name); + auto in_shape = input_params.at(connection.outside_port); + input_type = in_shape.dtype(); + partial_shape = in_shape.shape(); + connection.inside_shape = partial_shape; + } else { + input_type = graph->FindNodeId(connection.inside_id) + ->output_type(connection.outside_port); + } + connection.connection_type = input_type; } - auto engine_plan = infer_object(trt_engine->serialize()); - VLOG(0) << "Serialized engine"; - const char* engine_plan_data = - static_cast(engine_plan->data()); - engine_plan_string = - string(engine_plan_data, engine_plan_data + engine_plan->size()); - } - TF_RETURN_IF_ERROR(weight_rmgr->Delete( - engine_name, engine_name)); - LOG(INFO) << "finished engine " << engine_name << " containing " - << s.subgraph_node_ids.size() << " nodes"; - - // Build the TRT op - tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp"); - TF_RETURN_IF_ERROR(SetInputList(s, &op_builder, &input_names, &input_dtypes)); - - VLOG(0) << "Finished op preparation"; - - auto status = op_builder.Attr("serialized_engine", engine_plan_string) - .Attr("input_nodes", input_names) - .Attr("output_nodes", output_names) - .Attr("OutT", output_dtypes) - .Device(s.device_name_) - .Finalize(s.trt_node); - - VLOG(0) << status.ToString() << " finished op building for " << engine_name - << " on device " << s.device_name_; + // Add dummy input/output nodes to the segment graphdef. + if (connection.is_input_edge) { + const string node_name = StrCat(kInputPHName, connection.port_number); + if (marker_nodes.count(node_name)) { + VLOG(1) << "Reusing input " << node_name << " for the edge " + << connection.outside_node_name << ":" + << connection.outside_port << " -> " + << connection.inside_node_name << ":" << connection.inside_port; + continue; + } + marker_nodes.insert(node_name); + auto seg_node = segment_def->add_node(); + tensorflow::NodeDefBuilder builder(node_name, "Placeholder"); + auto status = builder.Attr("shape", partial_shape) + .Attr("dtype", input_type) + .Finalize(seg_node); + VLOG(1) << "Constructing input " << node_name << " for the edge " + << connection.outside_node_name << ":" << connection.outside_port + << " -> " << connection.inside_node_name << ":" + << connection.inside_port; + } else { + const string node_name = StrCat(kOutputPHName, connection.port_number); + if (marker_nodes.count(node_name)) { + VLOG(1) << "Reusing output " << node_name << " for the edge " + << connection.inside_node_name << ":" << connection.inside_port + << " -> " << connection.outside_node_name << ":" + << connection.outside_port; + continue; + } + marker_nodes.insert(node_name); + auto seg_node = segment_def->add_node(); + tensorflow::NodeDefBuilder builder(node_name, "Identity"); + auto status = builder.Input(connection.inside_node_name, 0, input_type) + .Finalize(seg_node); + VLOG(1) << "Constructing output " << node_name << " for the edge " + << connection.inside_node_name << ":" << connection.inside_port + << " -> " << connection.outside_node_name << ":" + << connection.outside_port; + } + } // for each connection. + + std::unordered_map old_to_new_id_map; + // Copy internal nodes to new graphdef + string local_scope = graph->FindNodeId(*subgraph_node_ids.begin())->name(); + for (const auto node_id : subgraph_node_ids) { + const auto node = graph->FindNodeId(node_id); + local_scope = GetCommonNameScope(local_scope, node->name()); + old_to_new_id_map[node_id] = segment_def->node_size(); + auto snode = segment_def->add_node(); + snode->CopyFrom(node->def()); + VLOG(1) << "Copying " << snode->name() << " to subgraph"; + } + // Update the inputs of the new input nodes to point to placeholder nodes. + for (int i = 0; i < connections->size(); ++i) { + auto& connection = connections->at(i); + if (!connection.is_input_edge) continue; + auto snode = + segment_def->mutable_node(old_to_new_id_map[connection.inside_id]); + const string placeholder_name = + StrCat(kInputPHName, connection.port_number); + VLOG(1) << "Updating " << snode->name() << ":" << connection.inside_port + << " from " << snode->input(connection.inside_port) << " to " + << placeholder_name; + snode->set_input(connection.inside_port, placeholder_name); + } + *common_scope = local_scope; + VLOG(0) << "Segment @scope '" << local_scope << "', converted to graph"; return tensorflow::Status::OK(); } diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index 3f6592cd25ff013cadc0621ba64f0553983dd10b..1a4c0e755d1cd1e88ac26c39996eb3a750421a0a 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -22,69 +22,112 @@ limitations under the License. #include #include +#include "tensorflow/contrib/tensorrt/convert/utils.h" #include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" +#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/lib/core/status.h" + #if GOOGLE_CUDA #if GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { +static const char* kInputPHName = "InputPH_"; +static const char* kOutputPHName = "OutputPH_"; namespace convert { +// TODO(aaroey): use an enum instead. const int FP32MODE = 0; const int FP16MODE = 1; const int INT8MODE = 2; -struct SubGraphParams { - SubGraphParams( - tensorflow::Graph& inp_graph, - const std::set& subgraph_node_id_numbers, - const std::vector>& input_indices, - const std::vector>& output_indices, - size_t max_supported_batch_size, size_t max_consumed_workspace_size_bytes, - const tensorflow::grappler::GraphProperties& current_graph_properties, - std::unordered_map>* output_edges, - tensorflow::NodeDef* constructed_trt_node, - int engine_precision_mode = FP32MODE, const string& device_name = "", - std::shared_ptr allocator = nullptr, - int cuda_gpu_id = 0) - : graph(inp_graph), - subgraph_node_ids(subgraph_node_id_numbers), - input_inds(input_indices), - output_inds(output_indices), - max_batch_size(max_supported_batch_size), - max_workspace_size_bytes(max_consumed_workspace_size_bytes), - graph_properties(current_graph_properties), - output_edge_map(output_edges), - trt_node(constructed_trt_node), - precision_mode(engine_precision_mode), - device_name_(device_name), - allocator_(allocator), - cuda_gpu_id_(cuda_gpu_id) {} - - tensorflow::Graph& graph; - const std::set& subgraph_node_ids; - const std::vector>& input_inds; // {node_id, output_idx} - const std::vector>& output_inds; // {node_id, output_idx} - size_t max_batch_size; - size_t max_workspace_size_bytes; - const tensorflow::grappler::GraphProperties& graph_properties; - std::unordered_map>* output_edge_map; - tensorflow::NodeDef* trt_node; - const int precision_mode; - const string device_name_; - std::shared_ptr allocator_; - const int cuda_gpu_id_; +struct EngineConnection { + EngineConnection(const string& outside, int out_id, int out_port, + const string& inside, int in_id, int in_port, + bool input_edge, int port) + : outside_node_name(outside), + outside_id(out_id), + outside_port(out_port), + inside_node_name(inside), + inside_id(in_id), + inside_port(in_port), + is_input_edge(input_edge), + port_number(port) {} + + const string outside_node_name; + const int outside_id; + const int outside_port; + tensorflow::PartialTensorShape outside_shape; + + const string inside_node_name; + const int inside_id; + const int inside_port; + tensorflow::PartialTensorShape inside_shape; + + tensorflow::DataType connection_type; + bool is_input_edge; + + // The port number of the TRT node connecting to this edge. + int port_number; +}; + +struct EngineInfo { + EngineInfo() + : engine_type(EngineType::TRTStatic), + max_workspace_size_bytes(0), + precision_mode(FP32MODE) {} + + string engine_name; + string device; + tensorflow::GraphDef segment_graph_def; + + // The segment nodes that are on one side of the edges are topological sorted. + std::vector connections; + + enum class EngineType { TRTStatic = 0, TRTDynamic = 1 }; + EngineType engine_type; + int64 max_workspace_size_bytes; + int maximum_cached_engines; + std::vector cached_engine_batches; + int precision_mode; }; -// TODO(sami): Replace references with const reference or pointers -tensorflow::Status ConvertSubGraphToTensorRTNodeDef(SubGraphParams& params); -tensorflow::Status InjectCalibrationNode(SubGraphParams& params); -tensorflow::Status ConvertCalibrationNodeToEngineNode(tensorflow::Graph& graph, - tensorflow::Node* c_node); +// Constructs a graphdef from the segment in the given graph. Adds placeholder +// nodes for input edges (InputPH_*) and identity nodes for output edges +// (OutputPH_*). This function needs to be called before TensorRT nodes +// inserted in order to correctly get sizes from the original graph. +// +// - subgraph_node_ids: the node ids of the subgraph, must be sorted in +// topological order. +// - segment_def: the output GraphDef, whose non-input/output nodedefs will be +// sorted in topological order. +tensorflow::Status ConvertSegmentToGraphDef( + const tensorflow::Graph* graph, + const tensorflow::grappler::GraphProperties& graph_properties, + const std::vector& subgraph_node_ids, + std::vector* connections, + tensorflow::GraphDef* segment_def, string* common_scope); + +// Converts given subgraph to a TRT engine saved in 'engine'. Returns ok iff +// 'builder' successfully build the engine. If the result is not ok, 'engine' +// will be set to nullptr +// Once returned, 'builder' is not needed any more and can be safely detroyed. +// +// - convert_successfully: indicates whether the converson to TensorRT network +// is successful. This is different than successfully building the engine: +// building can still fail afterwards. +tensorflow::Status ConvertGraphDefToEngine( + const tensorflow::GraphDef& gdef, int precision_mode, int max_batch_size, + size_t max_workspace_size_bytes, + const std::vector& input_shapes, + Logger* logger, nvinfer1::IGpuAllocator* allocator, + TRTInt8Calibrator* calibrator, + TrtUniquePtrType* engine, + bool* convert_successfully); + } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc index 8f634b1f74717310a69a6bab5d5224c9bdbf10cc..ec9dbfa13bfd0a158dcf41cf1fdb7128a2adf641 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc @@ -45,8 +45,24 @@ tensorflow::Status TRTOptimizationPass::Init( if (params.count("max_batch_size")) { maximum_batch_size_ = params.at("max_batch_size").i(); } - if (params.count("max_workspace_size_bytes")) + is_dynamic_op_ = false; + if (params.count("is_dynamic_op")) { + is_dynamic_op_ = params.at("is_dynamic_op").b(); + } + if (params.count("cached_engine_batches")) { + auto batch_vec = params.at("cached_engine_batches").list(); + batches_.reserve(batch_vec.i_size()); + for (const auto i : batch_vec.i()) { + batches_.push_back(i); + } + } + max_cached_batches_ = 1; + if (params.count("maximum_cached_engines")) { + max_cached_batches_ = params.at("maximum_cached_engines").i(); + } + if (params.count("max_workspace_size_bytes")) { maximum_workspace_size_ = params.at("max_workspace_size_bytes").i(); + } if (params.count("precision_mode")) { string pm = Uppercase(params.at("precision_mode").s()); if (pm == "FP32") { @@ -175,6 +191,17 @@ tensorflow::Status TRTOptimizationPass::Optimize( if (VLOG_IS_ON(1)) { PrintDebugInfo(cluster, item); } + // This is a hack to workaround optimizer issue. MetaOptimizer calls + // optimization passes on function objects as well, we should not modify + // generated funcdefs! This is fragile but we don't have any other option + // until framework fixes it. + if (item.id != "tf_graph") { + LOG(WARNING) << name_ + << " is probably called on funcdef! This optimizer must *NOT* " + "be called on function objects."; + *optimized_graph = item.graph; + return tensorflow::Status::OK(); + } int max_dim = -1; if (item.feed.size()) { for (const auto& f : item.feed) { @@ -204,11 +231,22 @@ tensorflow::Status TRTOptimizationPass::Optimize( } tensorflow::grappler::GraphProperties static_graph_properties(item); TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); - auto status = tensorflow::tensorrt::convert::ConvertAfterShapes( - item.graph, item.fetch, maximum_batch_size_, maximum_workspace_size_, - optimized_graph, precision_mode_, minimum_segment_size_, - static_graph_properties, cluster); + tensorflow::tensorrt::convert::ConversionParams cp; + cp.input_graph_def = &item.graph; + cp.output_names = &item.fetch; + cp.max_batch_size = maximum_batch_size_; + cp.max_workspace_size_bytes = maximum_workspace_size_; + cp.output_graph_def = optimized_graph; + cp.precision_mode = precision_mode_; + cp.minimum_segment_size = minimum_segment_size_; + cp.graph_properties = &static_graph_properties; + cp.cluster = cluster; + cp.is_dyn_op = is_dynamic_op_; + cp.cached_engine_batches = batches_; + cp.max_cached_engines = max_cached_batches_; + auto status = tensorflow::tensorrt::convert::ConvertAfterShapes(cp); VLOG(2) << optimized_graph->DebugString(); + VLOG(1) << "Returning from " << name_; return status; } diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h index d8ecead23efaa5c3bab95b8ba481e2307b0af772..463ed3883e4808408104c618a289989472c497ea 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h @@ -61,6 +61,9 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer { int minimum_segment_size_; int precision_mode_; int maximum_batch_size_; + bool is_dynamic_op_; + std::vector batches_; + int max_cached_batches_; int64_t maximum_workspace_size_; }; diff --git a/tensorflow/contrib/tensorrt/convert/utils.h b/tensorflow/contrib/tensorrt/convert/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..f601c06701fdbf983b708cf5f5c7d22634bb810b --- /dev/null +++ b/tensorflow/contrib/tensorrt/convert/utils.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ + +#include + +namespace tensorflow { +namespace tensorrt { + +template +struct TrtDestroyer { + void operator()(T* t) { + if (t) t->destroy(); + } +}; + +template +using TrtUniquePtrType = std::unique_ptr>; + +} // namespace tensorrt +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ diff --git a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc deleted file mode 100644 index aea44fd8a2fcc4c359a6cb0c98ae34711708326e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc +++ /dev/null @@ -1,136 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/contrib/tensorrt/kernels/trt_calib_op.h" -#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/platform/stream_executor.h" - -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT -#include "cuda/include/cuda_runtime_api.h" -#include "tensorrt/include/NvInfer.h" - -namespace tensorflow { -namespace tensorrt { - -TRTCalibOp::TRTCalibOp(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("segment_nodes", &segment_nodes_)); - OP_REQUIRES_OK(context, context->GetAttr("input_names", &input_names_)); - OP_REQUIRES_OK(context, context->GetAttr("resource_name", &resource_name_)); -}; - -#define TYPECASE(dt, X, Y) \ - case dt: { \ - return (void*)X->flat::Type>().data(); \ - } - -void* GetTensorAddress(const Tensor* tensor_ptr) { - auto tensor_type = tensor_ptr->dtype(); - switch (tensor_type) { - TYPECASE(tensorflow::DT_FLOAT, tensor_ptr, dest_ptr); - TYPECASE(tensorflow::DT_HALF, tensor_ptr, dest_ptr); - TYPECASE(tensorflow::DT_INT8, tensor_ptr, dest_ptr); - default: { - LOG(FATAL) << "Unsupported Data type " - << tensorflow::DataTypeString(tensor_type); - return nullptr; - } - } -} - -void TRTCalibOp::Compute(tensorflow::OpKernelContext* ctx) { - // TODO(aaroey): make sure ctx->resource_mgr() is used in future PR. - auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance(); - auto res_mgr = trt_rm->getManager("TRTCalibOps"); - tensorflow::tensorrt::TRTCalibrationResource* calib_res = nullptr; - auto status = res_mgr->Lookup(resource_name_, resource_name_, &calib_res); - - if (!status.ok()) { - ctx->SetStatus(status); - return; - } - int num_inputs = ctx->num_inputs(); - // first run instantiate calibrator - if (calib_res->calibrator_ == nullptr) { - dev_tensors_.resize(num_inputs); - int batch_size = ctx->input(0).dim_size(0); - VLOG(1) << " Constructing calibrator"; - for (int i = 0; i < num_inputs; i++) { - // allocate workspace on device for inputs - const tensorflow::Tensor& t = ctx->input(i); - OP_REQUIRES_OK(ctx, - ctx->allocate_persistent(t.dtype(), t.shape(), - &dev_tensors_.at(i), nullptr)); - const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); - CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); - void* device_address = GetTensorAddress(device_tensor); - device_buffers_.emplace(input_names_.at(i), - std::pair( - device_address, device_tensor->TotalBytes())); - } - - calib_res->calibrator_ = - new TRTInt8Calibrator(device_buffers_, batch_size, resource_name_); - string label(resource_name_); - calib_res->thr_ = new std::thread([calib_res, label]() { - VLOG(1) << "Starting calibration thread, Calibration Resource @ " - << calib_res; - calib_res->builder_->setInt8Calibrator(calib_res->calibrator_); - calib_res->builder_->setInt8Mode(true); - calib_res->engine_ = calib_res->builder_->buildCudaEngine( - *calib_res->network_); // will loop until we terminate calibrator - VLOG(1) << "Calibration loop terminated " << label; - }); - VLOG(1) << "initialized calibrator resource"; - } // calibrator initialized - - // Pass input data to calibrator - std::unordered_map input_data; - for (int i = 0; i < num_inputs; i++) { - const Tensor& t = ctx->input(i); - void* data_address = GetTensorAddress(&t); - const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); - CHECK_EQ(t.TotalBytes(), - device_tensor->TotalBytes()); // use the tensor so FW keeps it - input_data.emplace(input_names_.at(i), data_address); - ctx->set_output(i, t); - } - VLOG(2) << "Filled map for sending"; - // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files - const cudaStream_t* stream = CHECK_NOTNULL( - reinterpret_cast(ctx->op_device_context() - ->stream() - ->implementation() - ->CudaStreamMemberHack())); - calib_res->calibrator_->setBatch(input_data, *stream); - VLOG(2) << "Passed calibration data"; - // TODO(aaroey): make sure we wait for the completion of calibration on the - // last batch in future PR. -}; - -#undef TYPECASE - -REGISTER_KERNEL_BUILDER(Name("TRTCalibOp").Device(DEVICE_GPU), TRTCalibOp); - -} // namespace tensorrt -} // namespace tensorflow -#endif -#endif diff --git a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h deleted file mode 100644 index 23df9db32f077a080eaff7479fcbe90d6a504c42..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H -#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H - -#include -#include -#include -#include -#include -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/platform/types.h" - -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT -namespace tensorflow { -namespace tensorrt { -// TODO(sami): Convert this to async kernel! -class TRTCalibOp : public OpKernel { - public: - explicit TRTCalibOp(OpKernelConstruction* context); - - void Compute(OpKernelContext* context) override; - - private: - string resource_name_; - std::vector segment_nodes_; - std::vector input_names_; - std::vector shapes_; - std::unordered_map> device_buffers_; - std::vector dev_tensors_; -}; -} // namespace tensorrt -} // namespace tensorflow -#endif -#endif -#endif // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 9ac8047944874181de228a6cc58e2dafe46abe50..8a17eb02f1af7c8f148c9cd4e14cc3876b6e13e3 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -14,8 +14,16 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h" +#include +#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/contrib/tensorrt/convert/utils.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/types.h" @@ -25,144 +33,556 @@ limitations under the License. #include "cuda/include/cuda_runtime_api.h" namespace tensorflow { -static ::tensorflow::tensorrt::Logger logger; -using IRuntime = nvinfer1::IRuntime; -using Dims = nvinfer1::Dims; - namespace tensorrt { +static Logger logger; +using ::nvinfer1::IRuntime; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + +// A helper class to call done() when destructed for asynchronous execution. +// Helps simultaneous execution of native and TRT engines. +class AsyncHelper : public tensorflow::core::RefCounted { + public: + AsyncHelper(tensorflow::AsyncOpKernel::DoneCallback done) { done_ = done; } + ~AsyncHelper() override { done_(); } -TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) { + private: + tensorflow::AsyncOpKernel::DoneCallback done_; +}; + +#define TYPECASE(dt, X, Y) \ + case dt: { \ + return (void*)X->flat::Type>().data(); \ + } + +void* GetTensorAddress(const Tensor* tensor_ptr) { + auto tensor_type = tensor_ptr->dtype(); + switch (tensor_type) { + TYPECASE(tensorflow::DT_FLOAT, tensor_ptr, dest_ptr); + TYPECASE(tensorflow::DT_HALF, tensor_ptr, dest_ptr); + TYPECASE(tensorflow::DT_INT8, tensor_ptr, dest_ptr); + default: { + LOG(ERROR) << "Unsupported Data type " + << tensorflow::DataTypeString(tensor_type); + return nullptr; + } + } +} + +tensorflow::Status TRTEngineOp::ConstructFunctionHandle(OpKernelContext* ctx) { + VLOG(1) << "Constructing function handle"; + auto lib = ctx->function_library(); + if (lib == nullptr) { + return tensorflow::errors::Internal("Context function library is null"); + } + auto fdef = lib->GetFunctionLibraryDefinition()->Find(funcdef_name_); + if (fdef == nullptr) { + return tensorflow::errors::Internal("Native FunctionDef ", funcdef_name_, + " can't be found in function library"); + } + tensorflow::FunctionLibraryRuntime::InstantiateOptions inst_ops; + inst_ops.overlay_lib = nullptr; + inst_ops.state_handle = ""; + inst_ops.target = ctx->device()->name(); + native_func_ = 0; + auto status = lib->Instantiate(funcdef_name_, AttrSlice(&fdef->attr()), + inst_ops, &native_func_); + if (!status.ok()) { + LOG(ERROR) << " Instantiating native function " << funcdef_name_ + << " failed!"; + } + return status; +} + +TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) + : AsyncOpKernel(context) { // read serialized_engine OP_REQUIRES_OK(context, - context->GetAttr("serialized_engine", &serialized_engine_)); + context->GetAttr("serialized_segment", &serialized_segment_)); + OP_REQUIRES_OK(context, + context->GetAttr("workspace_size_bytes", &workspace_size_)); + OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_)); + if (!static_engine_) { + if (!segment_graph_.ParseFromString(serialized_segment_)) { + LOG(ERROR) << "Parsing segment graph failed!"; + context->SetStatus(tensorflow::errors::InvalidArgument( + "Failed to parse segment graphdef!")); + return; + } + serialized_segment_.resize(0); + } + VLOG(1) << "Constructing " << name(); + string precision_string; + OP_REQUIRES_OK(context, + context->GetAttr("precision_mode", &precision_string)); + string calibration_data; + OP_REQUIRES_OK(context, + context->GetAttr("calibration_data", &calibration_data)); + OP_REQUIRES_OK(context, + context->GetAttr("segment_funcdef_name", &funcdef_name_)); + if (precision_string == "FP32") { + precision_mode_ = convert::FP32MODE; + } else if (precision_string == "FP16") { + precision_mode_ = convert::FP16MODE; + } else if (precision_string == "INT8") { + precision_mode_ = convert::INT8MODE; + } + calibration_mode_ = + (precision_mode_ == convert::INT8MODE && calibration_data.size() == 0); + if (calibration_data.size()) { + calibrator_.reset(new TRTInt8Calibrator(calibration_data)); + calibration_data.resize(0); + } + native_func_ = tensorflow::kInvalidHandle; + OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count", + &max_cached_engines_)); + OP_REQUIRES_OK(context, + context->GetAttr("fixed_input_size", &fixed_input_size_)); + OP_REQUIRES_OK(context, context->GetAttr("cached_engine_batches", + &cached_engine_batches_)); + std::sort(cached_engine_batches_.begin(), cached_engine_batches_.end()); + if (VLOG_IS_ON(1)) { + string s("Engine Batches= "); + for (auto i : cached_engine_batches_) { + StrAppend(&s, i, " "); + } + VLOG(1) << s; + } +} - // register input output node name in trt_sub_graph - OP_REQUIRES_OK(context, context->GetAttr("input_nodes", &input_nodes_)); - OP_REQUIRES_OK(context, context->GetAttr("output_nodes", &output_nodes_)); +void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx, + AsyncHelper* helper) { + if (!calibration_mode_) { + VLOG(1) << "Executing native engine"; + } + std::vector inputs; + std::vector* outputs = new std::vector(); + if (native_func_ == tensorflow::kInvalidHandle) { + auto status = ConstructFunctionHandle(ctx); + if (!status.ok()) { + LOG(ERROR) << "Couldn't construct function handle " << funcdef_name_; + ctx->SetStatus(status); + return; + } + } + auto lib = ctx->function_library(); + tensorflow::FunctionLibraryRuntime::Options opts; + opts.step_id = ctx->step_id(); + opts.rendezvous = ctx->rendezvous(); + opts.cancellation_manager = ctx->cancellation_manager(); + opts.runner = ctx->runner(); + for (int i = 0; i < ctx->num_inputs(); i++) { + inputs.push_back(ctx->input(i)); + } + helper->Ref(); // Increment count for calculating native graph + VLOG(1) << "Executing native segment " << name(); + lib->Run(opts, native_func_, inputs, outputs, + [ctx, outputs, helper](const tensorflow::Status& s) { + tensorflow::core::ScopedUnref sc(helper); + VLOG(1) << "Native Segment completed"; + if (!s.ok()) { + ctx->SetStatus(s); + return; + } + for (size_t t = 0; t < outputs->size(); ++t) { + ctx->set_output(t, outputs->at(t)); + } + delete outputs; + }); } -void TRTEngineOp::Compute(OpKernelContext* context) { - // TODO(samikama) runtime should be taken from a resourcemanager as well. - // Only engine should be in the op and context and runtime should be taken - // from resourcemanager +void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx, + AsyncHelper* helper) { + helper->Ref(); + tensorflow::core::ScopedUnref sc(helper); + // TODO(aaroey): remove the ResourceMgr singleton. + auto trt_rm = TRTResourceManager::instance(); + auto res_mgr = trt_rm->getManager("TRTCalibration"); + TRTCalibrationResource* calib_res = nullptr; + auto status = res_mgr->LookupOrCreate( + funcdef_name_, "Calibrator", &calib_res, + {[ctx, this](TRTCalibrationResource** cr) -> tensorflow::Status { + return this->AllocateCalibrationResources(ctx, cr); + }}); + if (!status.ok()) { + ctx->SetStatus(status); + return; + } + int num_inputs = ctx->num_inputs(); + // Pass input data to calibrator + std::unordered_map input_data; + for (int i = 0; i < num_inputs; i++) { + const Tensor& t = ctx->input(i); + void* data_address = GetTensorAddress(&t); + if (data_address == nullptr) { + ctx->SetStatus(tensorflow::errors::InvalidArgument( + "Unsupported data type encountered in input ", i)); + return; + } + // Check the allocated buffer is sufficient for input + const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); + CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); + input_data.emplace(StrCat(kInputPHName, i), data_address); + } + VLOG(2) << "Filled map for sending"; + // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files + const cudaStream_t* stream = CHECK_NOTNULL( + reinterpret_cast(ctx->op_device_context() + ->stream() + ->implementation() + ->CudaStreamMemberHack())); + calib_res->calibrator_->setBatch(input_data, *stream); + VLOG(2) << "Passed calibration data"; + ExecuteNativeSegment(ctx, helper); +} - if (!trt_execution_context_ptr_) { - IRuntime* infer = nvinfer1::createInferRuntime(logger); -#if NV_TENSORRT_MAJOR > 3 - auto device = context->device(); - auto dev_allocator = - device->GetAllocator(tensorflow::AllocatorAttributes()); - if (!dev_allocator) { - LOG(FATAL) << "Can't find device allocator for gpu device " - << device->name(); - } - allocator_ = std::make_shared(dev_allocator); - infer->setGpuAllocator(allocator_.get()); -#endif - trt_engine_ptr_.reset(infer->deserializeCudaEngine( - serialized_engine_.c_str(), serialized_engine_.size(), - PluginFactoryTensorRT::GetInstance())); - trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext()); - // Runtime is safe to delete after engine creation - infer->destroy(); - serialized_engine_.clear(); +int TRTEngineOp::GetEngineBatch(tensorflow::OpKernelContext* ctx) { + int num_batch = ctx->input(0).shape().dim_size(0); + int smallest_engine = 0; + for (const auto i : cached_engine_batches_) { + if (i >= num_batch) { + smallest_engine = i; + break; + } } - int num_binding = context->num_inputs() + context->num_outputs(); - std::vector buffers(num_binding); + // TODO(sami): Need an LRU here + if (smallest_engine == 0) { + if (max_cached_engines_ > cached_engine_batches_.size()) { + smallest_engine = num_batch; + cached_engine_batches_.push_back(num_batch); + VLOG(1) << "Running with batch size " << num_batch; + } else { + string s("Engine buffer is full. buffer limit= "); + StrAppend(&s, max_cached_engines_, ", current entries= "); + for (auto i : cached_engine_batches_) StrAppend(&s, i, ", "); + StrAppend(&s, "Requested batch= ", num_batch); + LOG(ERROR) << s; + ctx->SetStatus(tensorflow::errors::ResourceExhausted( + "Requested batch size is not available and engine cache is full")); + return -1; + } + } + return smallest_engine; +} - size_t binding_index; - int num_batch = 0; - for (int i = 0; i < context->num_inputs(); i++) { - // Grab the input tensor - binding_index = trt_engine_ptr_->getBindingIndex(input_nodes_[i].c_str()); +void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, + tensorflow::AsyncOpKernel::DoneCallback done) { + auto helper = new AsyncHelper(done); + tensorflow::core::ScopedUnref sc(helper); + if (calibration_mode_) { + ExecuteCalibration(ctx, helper); + return; + } + const int smallest_engine = GetEngineBatch(ctx); + if (smallest_engine < 0) return; // GetEngineBatch already set the status. + + const int num_batch = ctx->input(0).shape().dim_size(0); + auto& engine_ctx_pair = GetEngine(smallest_engine, ctx); + auto& trt_engine_ptr = engine_ctx_pair.first; + if (!trt_engine_ptr) { + LOG(WARNING) << "Engine retrieval for batch size " << num_batch + << " failed Running native segment"; + ExecuteNativeSegment(ctx, helper); + return; + } - const Tensor& input_tensor = context->input(i); + const int num_binding = ctx->num_inputs() + ctx->num_outputs(); + std::vector buffers(num_binding); + for (int i = 0; i < ctx->num_inputs(); i++) { + const string inp_name = StrCat(kInputPHName, i); + const size_t binding_index = + trt_engine_ptr->getBindingIndex(inp_name.c_str()); + + const Tensor& input_tensor = ctx->input(i); const TensorShape& input_shape = input_tensor.shape(); - if (i == 0) { - num_batch = input_shape.dim_size(0); - if (num_batch > trt_engine_ptr_->getMaxBatchSize()) { - LOG(FATAL) << "input tensor batch larger than max_batch_size: " - << trt_engine_ptr_->getMaxBatchSize(); - } - } else if (num_batch != input_shape.dim_size(0)) { - LOG(FATAL) << "input data inconsistent batch size"; - break; + if (num_batch != input_shape.dim_size(0)) { + LOG(ERROR) << "input data inconsistent batch size"; + ctx->SetStatus(tensorflow::errors::FailedPrecondition( + "Different batch sizes between input tensors")); + return; } - auto dtype = trt_engine_ptr_->getBindingDataType(binding_index); + auto dtype = trt_engine_ptr->getBindingDataType(binding_index); switch (dtype) { case nvinfer1::DataType::kFLOAT: buffers[binding_index] = (void*)(input_tensor.flat().data()); break; case nvinfer1::DataType::kHALF: - LOG(FATAL) << "half size is not supported yet!"; - break; + LOG(ERROR) << "FP16 inputs are not supported yet!"; + ctx->SetStatus(tensorflow::errors::InvalidArgument( + "FP16 inputs are not supported!")); + return; case nvinfer1::DataType::kINT8: - LOG(FATAL) << "int8 is not supported yet!"; - break; + LOG(ERROR) << "INT8 inputs are not supported yet!"; + ctx->SetStatus(tensorflow::errors::InvalidArgument( + "INT8 inputs are not supported!")); + return; default: - LOG(FATAL) << "Unknown data type: " << int(dtype); - break; + LOG(ERROR) << "Unknown TRT data type: " << int(dtype); + ctx->SetStatus(tensorflow::errors::InvalidArgument( + "Unknown output TRT data type! ", static_cast(dtype))); + return; } } - for (int i = 0; i < static_cast(output_nodes_.size()); i++) { - // This is bad that we have to reallocate output buffer every run. + for (int i = 0; i < ctx->num_outputs(); i++) { // Create an output tensor - binding_index = trt_engine_ptr_->getBindingIndex(output_nodes_[i].c_str()); + const string output_name = StrCat(kOutputPHName, i); + const size_t binding_index = + trt_engine_ptr->getBindingIndex(output_name.c_str()); Tensor* output_tensor = nullptr; TensorShape output_shape; if (binding_index != -1) { - auto dims = trt_engine_ptr_->getBindingDimensions(binding_index); + auto dims = trt_engine_ptr->getBindingDimensions(binding_index); std::vector trt_shape(dims.nbDims + 1); trt_shape[0] = num_batch; for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j]; - OP_REQUIRES_OK(context, - TensorShapeUtils::MakeShape( - trt_shape.data(), trt_shape.size(), &output_shape)); + OP_REQUIRES_OK( + ctx, TensorShapeUtils::MakeShape(trt_shape.data(), trt_shape.size(), + &output_shape)); } else { - LOG(FATAL) << "output node not found, at " << output_nodes_[i]; - break; + LOG(ERROR) << "output node not found, at " << output_name; + ctx->SetStatus(tensorflow::errors::Internal("output ", output_name, + " couldn't be found!")); + return; } - - OP_REQUIRES_OK(context, - context->allocate_output(i, output_shape, &output_tensor)); - auto dtype = trt_engine_ptr_->getBindingDataType(binding_index); + auto status = ctx->allocate_output(i, output_shape, &output_tensor); + if (!status.ok()) { + LOG(ERROR) << "Allocating output failed with " << status; + ctx->SetStatus(status); + return; + } + auto dtype = trt_engine_ptr->getBindingDataType(binding_index); switch (dtype) { case nvinfer1::DataType::kFLOAT: buffers[binding_index] = reinterpret_cast(output_tensor->flat().data()); break; case nvinfer1::DataType::kHALF: - LOG(FATAL) << "half size is not supported yet!"; - break; + LOG(ERROR) << "half size is not supported yet!"; + ctx->SetStatus(tensorflow::errors::InvalidArgument( + "Half outputs are not supported!")); + return; case nvinfer1::DataType::kINT8: - LOG(FATAL) << "int8 is not supported yet!"; - break; + LOG(ERROR) << "int8 is not supported yet!"; + ctx->SetStatus(tensorflow::errors::InvalidArgument( + "INT8 outputs are not supported!")); + return; default: - LOG(FATAL) << "Unknown data type: " << int(dtype); - break; + LOG(ERROR) << "Unknown TRT data type: " << static_cast(dtype); + ctx->SetStatus(tensorflow::errors::InvalidArgument( + "Unsupported output data type! ", static_cast(dtype))); + return; } } // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files const cudaStream_t* stream = CHECK_NOTNULL( - reinterpret_cast(context->op_device_context() + reinterpret_cast(ctx->op_device_context() ->stream() ->implementation() ->CudaStreamMemberHack())); // TODO(jie): trt enqueue does not return error - auto ret = trt_execution_context_ptr_->enqueue(num_batch, &buffers[0], - *stream, nullptr); - VLOG(2) << "enqueue returns: " << ret; + auto& trt_execution_context_ptr = engine_ctx_pair.second; + auto ret = trt_execution_context_ptr->enqueue(num_batch, &buffers[0], *stream, + nullptr); + if (!ret) { + LOG(ERROR) << "Failed to enqueue batch for TRT engine: " << name(); + ctx->SetStatus(tensorflow::errors::Internal( + "Failed to enqueue batch for TRT engine: ", name())); + } // sync should be done by TF. } + TRTEngineOp::~TRTEngineOp() { - // Order matters! - trt_execution_context_ptr_.reset(); - trt_engine_ptr_.reset(); + // We need to manually destroy the engine and execution context before + // the allocator is destructed. + for (auto& eng : engine_map_) { + eng.second.first.reset(); + eng.second.second.reset(); + } allocator_.reset(); } + +nvinfer1::IGpuAllocator* TRTEngineOp::GetAllocator(OpKernelContext* ctx) { + if (allocator_) return allocator_.get(); + auto device = ctx->device(); + auto alloc = device->GetAllocator(tensorflow::AllocatorAttributes()); + if (!alloc) { + LOG(ERROR) << "Can't find device allocator for gpu device " + << device->name(); + ctx->SetStatus(tensorflow::errors::Internal( + "Can't get device allocator for device ", device->name())); + return nullptr; + } + allocator_.reset(new TRTDeviceAllocator(alloc)); + return allocator_.get(); +} + +TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, + OpKernelContext* ctx) { + static EngineCtxPair null_pair = { + TrtUniquePtrType(nullptr), + TrtUniquePtrType(nullptr)}; + // TODO(sami): This method needs to be re-written to use resource manager and + // with LRU mechanism option. + tensorflow::mutex_lock lock(engine_mutex_); + + if (static_engine_) { + if (engine_map_.size()) { + if (engine_map_.begin()->first >= batch_size) { + return engine_map_.begin()->second; + } + return null_pair; + } + TrtUniquePtrType infer(nvinfer1::createInferRuntime(logger)); +#if NV_TENSORRT_MAJOR > 3 + auto allocator = GetAllocator(ctx); + if (allocator == nullptr) { + // GetAllocator already set the Status. + return null_pair; + } + infer->setGpuAllocator(allocator); +#endif + TrtUniquePtrType static_engine( + infer->deserializeCudaEngine(serialized_segment_.c_str(), + serialized_segment_.size(), nullptr)); + auto raw_static_engine = static_engine.get(); + const auto max_batch_size = raw_static_engine->getMaxBatchSize(); + engine_map_[max_batch_size] = { + std::move(static_engine), + TrtUniquePtrType( + raw_static_engine->createExecutionContext())}; + // Runtime is safe to delete after engine creation + serialized_segment_.clear(); + if (max_batch_size < batch_size) return null_pair; + return engine_map_.at(max_batch_size); + } // static_engine_ + + // Handle the dynamic engine case. + auto engine_it = engine_map_.find(batch_size); + if (engine_it == engine_map_.end() && + engine_map_.size() < (size_t)max_cached_engines_) { + nvinfer1::IGpuAllocator* allocator = nullptr; +#if NV_TENSORRT_MAJOR > 3 + allocator = GetAllocator(ctx); + if (allocator == nullptr) { + // GetAllocator already set the Status. + return null_pair; + } +#endif + std::vector shapes; + for (int i = 0; i < ctx->num_inputs(); ++i) { + shapes.emplace_back(ctx->input(i).shape()); + } + TrtUniquePtrType engine; + bool convert_successfully = false; + VLOG(0) << name() << " Constructing a new engine with batch size " + << batch_size; + // Up to this point, calibrator_ can never be empty, since otherwise it + // means calibration_mode_ is true and this path won't get executed. + auto status = convert::ConvertGraphDefToEngine( + segment_graph_, precision_mode_, batch_size, workspace_size_, shapes, + &logger, allocator, calibrator_.get(), &engine, &convert_successfully); + if (!status.ok()) { + if (convert_successfully) { + // This means it fail to build the engine even when the network is built + // successfully, probably due to internal issues. In this case we don't + // retry in the future. + engine_map_[batch_size] = {nullptr, nullptr}; + } + LOG(ERROR) << "Engine creation for batch size " << batch_size + << " failed " << status; + ctx->SetStatus(tensorflow::errors::Internal("Engine creation failed!")); + return null_pair; + } + VLOG(1) << "Conversion is done"; + TrtUniquePtrType exec_context( + engine->createExecutionContext()); + engine_map_[batch_size] = {std::move(engine), std::move(exec_context)}; + } + return engine_map_.at(batch_size); +} + +tensorflow::Status TRTEngineOp::AllocateCalibrationResources( + tensorflow::OpKernelContext* ctx, TRTCalibrationResource** cr) { + auto cres = new TRTCalibrationResource(); + *cr = cres; + // Get the allocator. + auto alloc = ctx->device()->GetAllocator(tensorflow::AllocatorAttributes()); + if (!alloc) { + LOG(WARNING) << "Can't get device allocator will not be able to " + "allocate memory from TensorFlow memory pool"; + cres->allocator_.reset(new TRTCudaAllocator); + } else { + cres->allocator_.reset(new TRTDeviceAllocator(alloc)); + } + // Get the input shapes. + const int batch_size = ctx->input(0).dim_size(0); + const int num_inputs = ctx->num_inputs(); + std::vector shapes; + dev_tensors_.resize(num_inputs); + VLOG(1) << " Constructing calibrator"; + for (int i = 0; i < num_inputs; i++) { + // allocate workspace on device for inputs + const tensorflow::Tensor& t = ctx->input(i); + shapes.emplace_back(t.shape()); + Tensor* device_tensor; + TF_RETURN_IF_ERROR(ctx->allocate_persistent( + t.dtype(), t.shape(), &dev_tensors_.at(i), &device_tensor)); + CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); + void* device_address = GetTensorAddress(device_tensor); + if (device_address == nullptr) { + return tensorflow::errors::InvalidArgument( + "Unsupported data type encountered in input ", i); + } + device_buffers_.emplace( + StrCat(kInputPHName, i), + std::pair(device_address, device_tensor->TotalBytes())); + } + cres->calibrator_.reset( + new TRTInt8Calibrator(device_buffers_, batch_size, name())); + const string label(name()); + auto segment_graph = &segment_graph_; + const int cuda_gpu_id = ctx->device()->tensorflow_gpu_device_info()->gpu_id; + if (cuda_gpu_id < 0) { + LOG(ERROR) << "Can't get gpu_device_info from context->device()"; + return tensorflow::errors::InvalidArgument( + "Context->device doesn't contain device info!"); + } + const int64 workspace_size_bytes = workspace_size_; + cres->thr_.reset(new std::thread([cres, label, segment_graph, shapes, + cuda_gpu_id, workspace_size_bytes]() { + VLOG(0) << "Starting calibration thread on device " << cuda_gpu_id + << ", Calibration Resource @ " << cres; + auto err = cudaSetDevice(cuda_gpu_id); + if (err != cudaSuccess) { + // TODO(aaroey): should return error here. + LOG(ERROR) << "Couldn't set cuda device to " << cuda_gpu_id + << " in calibration thread"; + } + // ConvertGraphDefToEngine() will try to build the engine. This thread + // will loop inside buildCudaEngine() consuming the calibration data + // that is set by the TF op, and drive the builder until calibrator returns + // false. Engine is discarded after calibration table is generated + // + // TODO(aaroey): maybe setting the max batch size using the python + // calibration wrapper class. + auto s = convert::ConvertGraphDefToEngine( + *segment_graph, convert::INT8MODE, cres->calibrator_->getBatchSize(), + workspace_size_bytes, shapes, &cres->logger_, cres->allocator_.get(), + cres->calibrator_.get(), &cres->engine_, + /*convert_successfully=*/nullptr); + if (!s.ok()) { + LOG(ERROR) << "Calibration failed: " << s; + cres->calibrator_->setDone(); // Ignore further pushes + } + VLOG(1) << "Calibration loop terminated " << label; + })); + VLOG(1) << "initialized calibrator resource"; + return tensorflow::Status::OK(); +} + REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp); } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h index e613a71422852e60565ba7554516d7eace6b9cc7..6fe318be6a6bc9f01ce3b52e0430f2090b53002b 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h @@ -19,9 +19,14 @@ limitations under the License. #include #include +#include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/mutex.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -30,32 +35,95 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -class Logger; - +class TRTInt8Calibrator; +class TRTCalibrationResource; +class AsyncHelper; // TODO(Sami): Remove this file? -class TRTEngineOp : public OpKernel { + +// This OP can construct TRTEngine on the fly and if construction of engine +// fails, executes equivalent subgraph as a TensorFlow function. +class TRTEngineOp : public AsyncOpKernel { public: explicit TRTEngineOp(OpKernelConstruction* context); - void Compute(OpKernelContext* context) override; + void ComputeAsync(OpKernelContext* context, + AsyncOpKernel::DoneCallback done) override; ~TRTEngineOp(); private: - template - struct Destroyer { - void operator()(T* d) { d->destroy(); } - }; - - template - using destroyed_ptr = std::unique_ptr>; - destroyed_ptr trt_engine_ptr_; + // Execute calibration + void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper); + + // Construct a function handle for executing native funcdef graph + Status ConstructFunctionHandle(OpKernelContext* ctx); + + // Execute replaced native segment as function Op. + void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper); + + // Allocate necessary resources for calibration + Status AllocateCalibrationResources(OpKernelContext* ctx, + TRTCalibrationResource** cr); + // TODO(samikama): context should go to a resource manager! - destroyed_ptr trt_execution_context_ptr_; + typedef std::pair, + TrtUniquePtrType> + EngineCtxPair; + EngineCtxPair& GetEngine(int batch_size, OpKernelContext* ctx); + // Return engine batch closest to input batch. + int GetEngineBatch(OpKernelContext* ctx); + + nvinfer1::IGpuAllocator* GetAllocator(OpKernelContext* ctx); + + // map to keep engines and their execution context for given batch size. + std::unordered_map engine_map_; std::vector input_nodes_; std::vector output_nodes_; - std::shared_ptr allocator_; - string serialized_engine_; + + // keep device allocator for TRT. + std::unique_ptr allocator_; + + // serialized protobuf segment or trt engine depending on static_engine_ flag. + string serialized_segment_; + + // Name of the function for TF native execution of the segment. + string funcdef_name_; + + // GraphDef representation of the segment. + GraphDef segment_graph_; + + // Lookup table for temporary staging areas of input tensors for calibration. + std::unordered_map> device_buffers_; + + // Temporary staging areas for calibration inputs. + std::vector dev_tensors_; + + // Engine Precision mode. + int precision_mode_; + + // Whether engine is constructed during the conversion or needs to be + // constructed from protobuf segment. + bool static_engine_; + + // Whether to calibrate INT8 engine. + bool calibration_mode_; + + // Whether non-batch ranks of the inputs are assumed to be fixed or not for + // engine construction. + bool fixed_input_size_; + + // Batches of the cached engines + std::vector cached_engine_batches_; + + // Maximum number of cached engines + int max_cached_engines_; + + int64 workspace_size_; + mutex engine_mutex_; + FunctionLibraryRuntime::Handle native_func_; + + // The finalized calibrator for inference. + std::unique_ptr calibrator_; }; } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc b/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc deleted file mode 100644 index 4835e5065068ec7a59995eb7f6126b31aecf6704..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" -namespace tensorflow { - -REGISTER_OP("TRTCalibOp") - .Attr("segment_nodes: list(string)") // names of the ops in segment - .Attr("segment_output_names: list(string)") // names of the output ops in - // segment - .Attr("input_names: list(string)") // names of the inputs for - // passing into tensorrt - .Attr("resource_name: string") - .Attr("InT: list({int8, float16, float32})") - .Input("in_tensor: InT") - .Output("out_tensor: InT") - .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) { - for (int i = 0; i < c->num_inputs(); i++) { - c->set_output(i, c->input(i)); - } - return Status::OK(); - }); - -} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc index 079d73f7bec3f9a9740e455b31a259cec287f849..383635f428812984915a8c46ad3b92cc7b28a5f7 100644 --- a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc @@ -28,11 +28,19 @@ extern Status TRTEngineOpShapeInference(InferenceContext* c); } REGISTER_OP("TRTEngineOp") - .Attr("serialized_engine: string") - .Attr("input_nodes: list(string)") - .Attr("output_nodes: list(string)") - .Attr("InT: list({float32})") - .Attr("OutT: list({float32})") + .Attr("serialized_segment: string") + .Attr("input_shapes: list(shape)") + .Attr("output_shapes: list(shape)") + .Attr("segment_funcdef_name: string") + .Attr("InT: list({int8,float16,float32})") + .Attr("OutT: list({int8,float16,float32})") + .Attr("static_engine: bool = true") + .Attr("fixed_input_size: bool = true") + .Attr("cached_engine_batches: list(int) = []") + .Attr("max_cached_engines_count: int = 1") + .Attr("workspace_size_bytes: int") + .Attr("precision_mode: {'FP32', 'FP16', 'INT8', 'INT8CALIB'}") + .Attr("calibration_data: string = ''") .Input("in_tensor: InT") .Output("out_tensor: OutT") .SetShapeFn(shape_inference::TRTEngineOpShapeInference); diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index 338475d90ea55ab2c1bb8df77f27a71a4a36a5dd..79f512dbcf6bd4d84b98cf69630778734566391c 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -21,6 +21,8 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long import six as _six from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert +from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_version +from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 @@ -29,7 +31,9 @@ from tensorflow.python.framework import errors_impl as _impl from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.grappler import tf_optimizer +from tensorflow.python.platform import tf_logging from tensorflow.python.util import compat + # pylint: enable=unused-import,line-too-long @@ -40,7 +44,10 @@ def create_inference_graph(input_graph_def, max_batch_size=1, max_workspace_size_bytes=2 << 20, precision_mode="FP32", - minimum_segment_size=3): + minimum_segment_size=3, + is_dynamic_op=False, + maximum_cached_engines=1, + cached_engine_batches=[]): """Python wrapper for the TRT transformation. Args: @@ -51,6 +58,10 @@ def create_inference_graph(input_graph_def, precision_mode: one of 'FP32', 'FP16' and 'INT8' minimum_segment_size: the minimum number of nodes required for a subgraph to be replaced by TRTEngineOp. + is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT + network and engine at run time. + maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops. + cached_engine_batches: batch sizes used to pre-create cached engines. Returns: New GraphDef with TRTEngineOps placed in graph replacing subgraphs. @@ -65,6 +76,30 @@ def create_inference_graph(input_graph_def, "It should be one of {}").format( precision_mode, "{'FP32', 'FP16', 'INT8'}")) mode = supported_precision_modes[precision_mode.upper()] + compiled_version = get_linked_tensorrt_version() + loaded_version = get_loaded_tensorrt_version() + version_mismatch = False + if loaded_version[0] < compiled_version[0]: + tf_logging.error( + "TensorRT version mismatch. Tensorflow was compiled against " + + "TensorRT %s but library loaded from environment is TensorRT %s" % + (".".join([str(x) for x in compiled_version]), + ".".join([str(x) for x in loaded_version])) + + ". Please make sure that correct version of TensorRT " + + "is available in the system and added to ldconfig or LD_LIBRARY_PATH" + ) + raise RuntimeError("Incompatible TensorRT library version") + for i in zip(loaded_version, compiled_version): + if i[0] != i[1]: + tf_logging.warn("TensorRT mismatch. Compiled against version " + + "%s, but loaded %s. Things may not work" % + (".".join([str(x) for x in compiled_version]), + ".".join([str(x) for x in loaded_version]))) + version_mismatch = True + break + if not version_mismatch: + tf_logging.info("Running against TensorRT version %s" % ".".join( + [str(x) for x in loaded_version])) def py2bytes(inp): return inp @@ -100,7 +135,9 @@ def create_inference_graph(input_graph_def, # pair or strings where first one is encoded status and the second # one is the transformed graphs protobuf string. out = trt_convert(input_graph_def_str, out_names, max_batch_size, - max_workspace_size_bytes, mode, minimum_segment_size) + max_workspace_size_bytes, mode, minimum_segment_size, + is_dynamic_op, maximum_cached_engines, + cached_engine_batches) status = to_string(out[0]) output_graph_def_string = out[1] del input_graph_def_str # Save some memory @@ -120,11 +157,12 @@ def create_inference_graph(input_graph_def, return output_graph_def -def calib_graph_to_infer_graph(calibration_graph_def): +def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False): """Convert an existing calibration graph to inference graph. Args: calibration_graph_def: the calibration GraphDef object with calibration data + is_dynamic_op: whether to create dynamic static engines from calibration Returns: New GraphDef with TRTEngineOps placed in graph replacing calibration nodes. Raises: @@ -141,9 +179,16 @@ def calib_graph_to_infer_graph(calibration_graph_def): to_string = py2string else: to_string = py3string - + is_calib_graph = False + for n in calibration_graph_def.node: + if n.op == "TRTEngineOp": + is_calib_graph = is_calib_graph or not n.attr["calibration_data"].s + if not is_calib_graph: + tf_logging.error( + "Not a calib graph. Doesn't seem to contain any calibration nodes.") + return None graph_str = calibration_graph_def.SerializeToString() - out = calib_convert(graph_str) + out = calib_convert(graph_str, is_dynamic_op) status = to_string(out[0]) output_graph_def_string = out[1] del graph_str # Save some memory diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc index 0f0508331c13055096714352e83fc360f0ef39b4..9f115990c3a3e6e92093e5f0d82b985af1b25482 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc +++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc @@ -50,7 +50,7 @@ TRTDeviceAllocator::TRTDeviceAllocator(tensorflow::Allocator* allocator) } void TRTDeviceAllocator::free(void* memory) { - VLOG(2) << "Deallocating " << memory; + VLOG(2) << "Deallocating @ " << memory; allocator_->DeallocateRaw(memory); } diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.h b/tensorflow/contrib/tensorrt/resources/trt_allocator.h index a0c2540a7698bc46a65dbd967412351bac2a4dd2..c5d2cec730f4ae97e4c6bcc19897fd9f321122a7 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator.h +++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ #define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ - #include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/core/framework/allocator.h" @@ -52,7 +51,9 @@ class TRTDeviceAllocator : public nvinfer1::IGpuAllocator { // Allocator implementation wrapping TF device allocators. public: TRTDeviceAllocator(tensorflow::Allocator* allocator); - virtual ~TRTDeviceAllocator() {} + virtual ~TRTDeviceAllocator() { + VLOG(1) << "Destroying allocator attached to " << allocator_->Name(); + } void* allocate(uint64_t size, uint64_t alignment, uint32_t flags) override; void free(void* memory) override; diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc index dc7c93f869f5ef7c8eaa2a87eed26cfe69597fdb..32e81858b95d76a2baebb4804a1326fbbb6144c7 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" #include -#include #include #include "tensorflow/core/platform/logging.h" @@ -37,15 +36,22 @@ TRTInt8Calibrator::TRTInt8Calibrator( : batch_size_(batch_size), done_(false), dev_buffers_(dev_buffers), - calib_running_(false), + calib_running_(true), batch_is_set_(false), engine_name_(engine_name) {} +TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data) + : batch_size_(0), + done_(false), + calib_running_(false), + batch_is_set_(false), + calibration_table_(calib_data) {} + bool TRTInt8Calibrator::setBatch(const std::unordered_map& data, const cudaStream_t stream) { tensorflow::mutex_lock lock(cond_mtx_); - while ((calib_running_ || batch_is_set_) && - !done_) { // wait while calibration is running + // wait while calibration is running. + while ((calib_running_ || batch_is_set_) && !done_) { cond_.wait(lock); } if (done_) return false; @@ -59,8 +65,6 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map& data, } const auto& d = devptr->second; - // TODO(aaroey): we should not use sync copy on default stream. Make sure - // stream->ThenMemcpy() is used in future PRs. // TODO(sami,aaroey): Need to figure out a way to ensure synchronization // between stream, perhaps using a tensor? auto status = cudaMemcpyAsync(d.first, it.second, d.second, @@ -84,13 +88,11 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names, tensorflow::mutex_lock lock(cond_mtx_); calib_running_ = false; cond_.notify_all(); - while ((!batch_is_set_ && !done_)) { // wait until new batch arrives + // wait until new batch arrives + while ((!batch_is_set_ && !done_)) { cond_.wait(lock); - - } - if (done_) { - return false; } + if (done_) return false; for (int i = 0; i < num_bindings; i++) { auto it = dev_buffers_.find(names[i]); @@ -107,7 +109,9 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names, } const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) { - return nullptr; + if (calibration_table_.empty()) return nullptr; + length = calibration_table_.size(); + return calibration_table_.data(); } void TRTInt8Calibrator::setDone() { @@ -117,7 +121,11 @@ void TRTInt8Calibrator::setDone() { } void TRTInt8Calibrator::writeCalibrationCache(const void* ptr, - std::size_t length) {} + std::size_t length) { + calibration_table_ = string((const char*)ptr, length); + VLOG(1) << "Got calibration data for " << engine_name_ << " @" << ptr + << " length=" << length; +} TRTInt8Calibrator::~TRTInt8Calibrator() { VLOG(1) << "Destroying calibrator for " << engine_name_; } diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h index d77aa2c5ab184756adaee38f88180b3c128ebe03..994312d7c3c93ba04394b7d9542d261c57c5609b 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h @@ -39,29 +39,48 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { TRTInt8Calibrator( const std::unordered_map>& dev_buffers, int batch_size, string engine_name); + + TRTInt8Calibrator(const string& calibration_data); + + ~TRTInt8Calibrator(); + int getBatchSize() const override; + bool getBatch(void* bindings[], const char* names[], int num_bindings) override; + bool setBatch(const std::unordered_map& data, const cudaStream_t stream); + void setDone(); + + // If not null, calibration is skipped. const void* readCalibrationCache(std::size_t& length) override; + void writeCalibrationCache(const void* ptr, std::size_t length) override; - ~TRTInt8Calibrator(); + + const string& getCalibrationTableAsString() { return calibration_table_; } private: const int batch_size_; - tensorflow::mutex cond_mtx_; // mutex for condition_variable - tensorflow::condition_variable cond_; // condition variable to implement - // producer-consumer queue for - // calibration + + // mutex for condition_variable + tensorflow::mutex cond_mtx_; + + // condition variable to implement producer-consumer queue for calibration + tensorflow::condition_variable cond_; + + // Is calibration finished? bool done_; - const std::unordered_map> - dev_buffers_; // map to keep tensorrt input buffers and sizes keyed with - // buffer names + + // Map to keep tensorrt input buffers and sizes keyed with buffer names + const std::unordered_map> dev_buffers_; + bool calib_running_; bool batch_is_set_; + string engine_name_; + string calibration_table_; }; } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h index e3469124acd4b9f6f4dd81b9998aa60bfe469b35..b7d5ffd6748ba34c6c4ddbfbfbb44edb6bf2aca8 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_resources.h +++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "tensorflow/contrib/tensorrt/convert/utils.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" #include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" @@ -34,50 +35,48 @@ limitations under the License. namespace tensorflow { namespace tensorrt { + class TRTCalibrationResource : public tensorflow::ResourceBase { public: - TRTCalibrationResource() - : calibrator_(nullptr), - builder_(nullptr), - network_(nullptr), - engine_(nullptr), - logger_(nullptr), - thr_(nullptr) {} - ~TRTCalibrationResource() { VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString(); + builder_.reset(); + engine_.reset(); + // We need to manually destroy the builder and engine before the allocator + // is destroyed. + allocator_.reset(); } string DebugString() override { std::stringstream oss; - oss << " Calibrator = " << std::hex << calibrator_ << std::dec << std::endl - << " Builder = " << std::hex << builder_ << std::dec << std::endl - << " Network = " << std::hex << network_ << std::dec << std::endl - << " Engine = " << std::hex << engine_ << std::dec << std::endl - << " Logger = " << std::hex << logger_ << std::dec << std::endl - << " Allocator = " << std::hex << allocator_.get() << std::dec - << std::endl - << " Thread = " << std::hex << thr_ << std::dec << std::endl; + using std::dec; + using std::endl; + using std::hex; + oss << " Calibrator = " << hex << calibrator_.get() << dec << endl + << " Builder = " << hex << builder_.get() << dec << endl + << " Engine = " << hex << engine_.get() << dec << endl + << " Logger = " << hex << &logger_ << dec << endl + << " Allocator = " << hex << allocator_.get() << dec << endl + << " Thread = " << hex << thr_.get() << dec << endl; return oss.str(); } - TRTInt8Calibrator* calibrator_; - nvinfer1::IBuilder* builder_; - nvinfer1::INetworkDefinition* network_; - nvinfer1::ICudaEngine* engine_; - std::shared_ptr allocator_; - tensorflow::tensorrt::Logger* logger_; + std::unique_ptr calibrator_; + TrtUniquePtrType builder_; + TrtUniquePtrType engine_; + std::unique_ptr allocator_; + tensorflow::tensorrt::Logger logger_; // TODO(sami): Use threadpool threads! - std::thread* thr_; + std::unique_ptr thr_; }; -class TRTWeightStore : public tensorflow::ResourceBase { +class TRTWeightStore { public: TRTWeightStore() {} virtual ~TRTWeightStore() { VLOG(1) << "Destroying store" << DebugString(); } - string DebugString() override { + string DebugString() { std::stringstream oss; size_t len_bytes = 0; for (const auto& v : store_) { diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h index 1568dd915344e6ba982b5a5550cc5386e047ff9f..81b4bfe49fe375d19f4c7811459f38e25d2edea8 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.h +++ b/tensorflow/contrib/tensorrt/segment/segment.h @@ -29,8 +29,9 @@ namespace tensorflow { namespace tensorrt { namespace segment { -// vector of segments, each entry contains a device name and a set of nodes in -// segment +// Vector of segments, each entry contains a set of node names and a device name +// in the segment. +// TODO(aaroey): use node pointer instead of node name. using SegmentNodesVector = std::vector, string>>; struct SegmentOptions { @@ -48,6 +49,8 @@ struct SegmentOptions { // in the vector describes a subgraph by giving a set of the names of // all the NodeDefs in that subgraph. // @return the status. +// +// TODO(aaroey): remove this method. tensorflow::Status SegmentGraph( const tensorflow::GraphDef& gdef, const std::function& candidate_fn, diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc index f36495f6b69ecb2f2a8d730b9ae4919fea3c04b8..227ac120dde8c986379c687987cd1bd822d559f7 100644 --- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -29,61 +29,35 @@ namespace tensorflow { namespace shape_inference { tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) { - tensorflow::tensorrt::Logger logger; - string serialized_engine; - TF_RETURN_IF_ERROR(context->GetAttr("serialized_engine", &serialized_engine)); - nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger); - nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine( - serialized_engine.c_str(), serialized_engine.size(), - tensorrt::PluginFactoryTensorRT::GetInstance()); - - int num_batch = -1; - std::vector<::tensorflow::DataType> input_type; - TF_RETURN_IF_ERROR(context->GetAttr("InT", &input_type)); - for (size_t i = 0; i < context->num_inputs(); i++) { - // Check if input shape is legit - auto input_shape = context->input(i); - for (int j = 0; j < context->Rank(input_shape); j++) { - auto dim_handler = context->Dim(input_shape, j); - if (j == 0) { - if (i == 0) { - num_batch = context->Value(dim_handler); - } else if (num_batch != context->Value(dim_handler)) { - // TODO(jie): TensorRT engine requires consistent batch between inputs - // tensors. Segmenter should be aware of this. - LOG(FATAL) << "TensorRT engine requires consistent batch size"; - } - } - } + std::vector shapes; + for (int i = 0; i < context->num_outputs(); ++i) { + context->set_output(i, context->UnknownShape()); } - - // Arrange input here - std::vector input_nodes; - TF_RETURN_IF_ERROR(context->GetAttr("input_nodes", &input_nodes)); - - // Arrange output here - std::vector output_nodes; - TF_RETURN_IF_ERROR(context->GetAttr("output_nodes", &output_nodes)); - for (size_t i = 0; i < output_nodes.size(); i++) { - int binding_index = trt_engine->getBindingIndex(output_nodes[i].c_str()); - ShapeHandle output_shape; - std::vector dim_vec; - dim_vec.emplace_back(context->MakeDim(num_batch)); - if (binding_index != -1) { - auto dims = trt_engine->getBindingDimensions(binding_index); - for (int j = 0; j < dims.nbDims; j++) { - dim_vec.emplace_back(context->MakeDim(dims.d[j])); - } - } else { - LOG(FATAL) << "TensorRT engine cannot find binding: " << output_nodes[i]; - } - output_shape = context->MakeShape(dim_vec); - context->set_output(i, output_shape); + auto status = context->GetAttr("input_shapes", &shapes); + // it is ok to not to have shapes + if (!status.ok()) return Status::OK(); + if ((int)shapes.size() != context->num_inputs()) return Status::OK(); + bool different_input = false; + for (int i = 0; i < context->num_inputs(); ++i) { + if (shapes.at(i) != context->input_tensor(i)->shape()) + different_input = true; + } + if (different_input) return Status::OK(); + shapes.resize(0); + status = context->GetAttr("output_shapes", &shapes); + if (!status.ok()) return Status::OK(); + if ((int)shapes.size() != context->num_outputs()) return Status::OK(); + std::vector shape_handles(shapes.size()); + for (size_t i = 0; i < shapes.size(); ++i) { + status = + context->MakeShapeFromTensorShape(shapes.at(i), &shape_handles.at(i)); + if (!status.ok()) return Status::OK(); + } + for (int i = 0; i < context->num_outputs(); ++i) { + context->set_output(i, shape_handles.at(i)); } - return Status::OK(); } - } // namespace shape_inference } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py index 175ccd800686255092e241aa59568df407d6eebc..090aa8bdb0487973e186631af3b4edac48096a5f 100644 --- a/tensorflow/contrib/tensorrt/test/test_tftrt.py +++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py @@ -20,6 +20,7 @@ from __future__ import print_function import argparse import numpy as np +import six as _six # normally we should do import tensorflow as tf and then # tf.placeholder, tf.constant, tf.nn.conv2d etc but @@ -35,10 +36,75 @@ from tensorflow.python.framework import dtypes as dtypes from tensorflow.python.framework import importer as importer from tensorflow.python.framework import ops as ops from tensorflow.python.ops import array_ops as aops +from tensorflow.python.ops import math_ops as mops from tensorflow.python.ops import nn as nn from tensorflow.python.ops import nn_ops as nn_ops +def py2bytes(inp): + return inp + + +def py3bytes(inp): + return inp.encode("utf-8", errors="surrogateescape") + + +def py2string(inp): + return inp + + +def py3string(inp): + return inp.decode("utf-8") + + +if _six.PY2: + to_bytes = py2bytes + to_string = py2string +else: + to_bytes = py3bytes + to_string = py3string + + +def get_multi_engine_graph_def(mode="FP32"): + """Create a simple graph and return its graph_def.""" + dtype = dtypes.float32 + if mode.upper() == "FP16": + dtype = dtypes.float16 + else: + pass + + g = ops.Graph() + with g.as_default(): + x = aops.placeholder(shape=[None, 3, 7, 5], name="input", dtype=dtype) + with g.name_scope("Global_scope"): + with g.name_scope("first_scope"): + e = cop.constant( + np.random.randn(3, 2, 3, 4), name="weights", dtype=dtype) + conv = nn.conv2d( + input=x, + filter=e, + data_format="NCHW", + strides=[1, 1, 1, 1], + padding="VALID", + name="conv") + b = cop.constant(np.random.randn(1, 4, 1, 1), name="bias1", dtype=dtype) + t = conv * b + + b = cop.constant(np.random.randn(1, 4, 1, 1), name="bias2", dtype=dtype) + q = conv / b + edge = mops.sin(q) + edge1 = mops.cos(conv) + with g.name_scope("test_scope"): + de = edge + edge1 + t -= edge1 + q *= edge + t += q + t -= de + k = aops.squeeze(t, name="output") + print(k.dtype) + return g.as_graph_def() + + def get_simple_graph_def(): """Create a simple graph and return its graph_def.""" g = ops.Graph() @@ -65,7 +131,9 @@ def get_simple_graph_def(): def execute_graph(gdef, dumm_inp): """Run given graphdef once.""" print("executing") - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + gpu_options = None + if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: + gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) sessconfig = cpb2.ConfigProto(gpu_options=gpu_options) ops.reset_default_graph() g = ops.Graph() @@ -83,7 +151,9 @@ def execute_graph(gdef, dumm_inp): # for calibration. For this test script it is random data. def execute_calibration(gdef, dumm_inp): """Run given calibration graph multiple times.""" - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + gpu_options = None + if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: + gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) ops.reset_default_graph() g = ops.Graph() with g.as_default(): @@ -100,12 +170,17 @@ def execute_calibration(gdef, dumm_inp): return val -def user(run_graph=execute_graph, run_calibration=execute_calibration): +def user(multi_engine, + run_graph=execute_graph, + run_calibration=execute_calibration): """Example function that converts a graph to TFTRT graph.""" - - inp_dims = (100, 24, 24, 2) + if multi_engine: + inp_dims = (2, 3, 7, 5) + orig_graph = get_multi_engine_graph_def() + else: + inp_dims = (100, 24, 24, 2) + orig_graph = get_simple_graph_def() # use a frozen graph for inference dummy_input = np.random.random_sample(inp_dims) - orig_graph = get_simple_graph_def() # use a frozen graph for inference # Get optimized graph trt_graph = trt.create_inference_graph( input_graph_def=orig_graph, @@ -113,8 +188,10 @@ def user(run_graph=execute_graph, run_calibration=execute_calibration): max_batch_size=inp_dims[0], max_workspace_size_bytes=1 << 25, precision_mode="FP32", # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2 # minimum number of nodes in an engine - ) + minimum_segment_size=2, # minimum number of nodes in an engine + is_dynamic_op=False, + maximum_cached_engines=1, + cached_engine_batches=[]) o1 = run_graph(orig_graph, dummy_input) o2 = run_graph(trt_graph, dummy_input) o3 = run_graph(trt_graph, dummy_input) @@ -126,40 +203,51 @@ def user(run_graph=execute_graph, run_calibration=execute_calibration): max_batch_size=inp_dims[0], max_workspace_size_bytes=1 << 25, precision_mode="FP16", # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2 # minimum number of nodes in an engine - ) + minimum_segment_size=2, # minimum number of nodes in an engine + is_dynamic_op=False, + maximum_cached_engines=1, + cached_engine_batches=[]) int8_calib_gdef = trt.create_inference_graph( input_graph_def=orig_graph, outputs=["output"], max_batch_size=inp_dims[0], max_workspace_size_bytes=1 << 25, precision_mode="INT8", # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2 # minimum number of nodes in an engine - ) + minimum_segment_size=2, # minimum number of nodes in an engine + is_dynamic_op=False, + maximum_cached_engines=1, + cached_engine_batches=[]) o4 = run_graph(fp16_graph, dummy_input) _ = run_calibration(int8_calib_gdef, dummy_input) int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef) o5 = run_graph(int8_graph, dummy_input) - assert np.allclose(o1, o4) - assert np.allclose(o1, o5) + print("Is FP32 == FP16? %s (False is possible)" % np.allclose(o1, o4)) + print("Is FP32 == INT8? %s (False is possible)" % np.allclose(o1, o5)) print("Pass") -def auto(): +def auto(multi_engine): """Run the conversion as an optimization pass.""" - inp_dims = (100, 24, 24, 2) + if multi_engine: + inp_dims = (2, 3, 7, 5) + orig_graph = get_multi_engine_graph_def() + else: + inp_dims = (100, 24, 24, 2) + orig_graph = get_simple_graph_def() # use a frozen graph for inference dummy_input = np.random.random_sample(inp_dims) - orig_graph = get_simple_graph_def() opt_config = rwpb2.RewriterConfig() + opt_config.meta_optimizer_iterations = opt_config.ONE opt_config.optimizers.extend(["constfold", "layout"]) custom_op = opt_config.custom_optimizers.add() custom_op.name = "TensorRTOptimizer" custom_op.parameter_map["minimum_segment_size"].i = 3 - custom_op.parameter_map["precision_mode"].s = "FP32" + custom_op.parameter_map["precision_mode"].s = to_bytes("FP32") custom_op.parameter_map["max_batch_size"].i = inp_dims[0] custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25 print(custom_op) - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + gpu_options = None + if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: + gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) graph_options = cpb2.GraphOptions(rewrite_options=opt_config) sessconfig = cpb2.ConfigProto( gpu_options=gpu_options, graph_options=graph_options) @@ -168,7 +256,7 @@ def auto(): ops.reset_default_graph() with g.as_default(): inp, out = importer.import_graph_def( - graph_def=orig_graph, return_elements=["input", "output"]) + graph_def=orig_graph, return_elements=["input", "output"], name="") inp = inp.outputs[0] out = out.outputs[0] with csess.Session(config=sessconfig, graph=g) as sess: @@ -186,8 +274,14 @@ if "__main__" in __name__: action="store_true", help="Do TRT conversion automatically", default=False) + P.add_argument( + "--multi-engine", + "-m", + action="store_true", + help="Use a graph that will result in 2 engines", + default=False) flags, unparsed = P.parse_known_args() if flags.automatic: - auto() + auto(flags.multi_engine) else: - user() + user(flags.multi_engine) diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py index 0403b652d72877196c3537a3181529aeeb997395..d9c41f90d0ab111b48c37aeaae5f0ce3177646c2 100644 --- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py +++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py @@ -18,131 +18,330 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from collections import namedtuple +import itertools import warnings import numpy as np +import six from tensorflow.contrib import tensorrt as trt -from tensorflow.core.protobuf import config_pb2 as cpb2 -from tensorflow.python.framework import constant_op as cop -from tensorflow.python.framework import dtypes as dtypes -from tensorflow.python.framework import importer as importer -from tensorflow.python.framework import ops as ops +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops as aops -from tensorflow.python.ops import nn as nn -from tensorflow.python.ops import nn_ops as nn_ops -from tensorflow.python.platform import googletest +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test +INPUT_NAME = "input" +OUTPUT_NAME = "output" +INPUT_DIMS = [100, 24, 24, 2] +MODE_FP32 = "FP32" +MODE_FP16 = "FP16" +MODE_INT8 = "INT8" -class IntegrationTest(test_util.TensorFlowTestCase): +if six.PY2: + to_bytes = lambda s: s + to_string = lambda s: s +else: + to_bytes = lambda s: s.encode("utf-8", errors="surrogateescape") + to_string = lambda s: s.decode("utf-8") + + +# TODO(aaroey): test graph with different dtypes. +def GetSingleEngineGraphDef(dtype=dtypes.float32): + """Create a graph containing single segment.""" + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtype, shape=[None] + INPUT_DIMS[1:], name=INPUT_NAME) + with g.device("/GPU:0"): + conv_filter = constant_op.constant( + [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], + name="weights", + dtype=dtype) + conv = nn.conv2d( + input=inp, + filter=conv_filter, + strides=[1, 2, 2, 1], + padding="SAME", + name="conv") + bias = constant_op.constant( + [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtype) + added = nn.bias_add(conv, bias, name="bias_add") + relu = nn.relu(added, "relu") + identity = array_ops.identity(relu, "identity") + pool = nn_ops.max_pool( + identity, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") + array_ops.squeeze(pool, name=OUTPUT_NAME) + return g.as_graph_def() + + +# TODO(aaroey): test graph with different dtypes. +def GetMultiEngineGraphDef(dtype=dtypes.float32): + """Create a graph containing multiple segment.""" + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtype, shape=[None] + INPUT_DIMS[1:], name=INPUT_NAME) + with g.device("/GPU:0"): + conv_filter = constant_op.constant( + [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], + name="weights", + dtype=dtype) + conv = nn.conv2d( + input=inp, + filter=conv_filter, + strides=[1, 2, 2, 1], + padding="SAME", + name="conv") + c1 = constant_op.constant( + np.random.randn(INPUT_DIMS[0], 12, 12, 6), dtype=dtype) + p = conv * c1 + c2 = constant_op.constant( + np.random.randn(INPUT_DIMS[0], 12, 12, 6), dtype=dtype) + q = conv / c2 + + edge = math_ops.sin(q) + edge /= edge + r = edge + edge + + p -= edge + q *= edge + s = p + q + s -= r + array_ops.squeeze(s, name=OUTPUT_NAME) + return g.as_graph_def() + + +TestGraph = namedtuple("TestGraph", + ["gdef", "num_expected_engines", "expected_output_dims"]) + +TEST_GRAPHS = { + "SingleEngineGraph": + TestGraph( + gdef=GetSingleEngineGraphDef(), + num_expected_engines=1, + expected_output_dims=(100, 6, 6, 6)), + "MultiEngineGraph": + TestGraph( + gdef=GetMultiEngineGraphDef(), + num_expected_engines=2, + expected_output_dims=(100, 12, 12, 6)), + # TODO(aaroey): add a large complex graph to test. +} + + +class TfTrtIntegrationTest(test_util.TensorFlowTestCase): """Class to test Tensorflow-TensorRT integration.""" def setUp(self): """Setup method.""" - super(IntegrationTest, self).setUp() + super(TfTrtIntegrationTest, self).setUp() warnings.simplefilter("always") - inp_dims = (100, 24, 24, 2) - self._input = np.random.random_sample(inp_dims) - self._original_graph = self.get_simple_graph_def() - self._gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) - self._config = cpb2.ConfigProto(gpu_options=self._gpu_options) - self._reference = self.run_graph(self._original_graph, self._input) - - def get_simple_graph_def(self): - """Create a simple graph and return its graph_def.""" - g = ops.Graph() - with g.as_default(): - a = aops.placeholder( - dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input") - e = cop.constant( - [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], - name="weights", - dtype=dtypes.float32) - conv = nn.conv2d( - input=a, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv") - b = cop.constant( - [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtypes.float32) - t = nn.bias_add(conv, b, name="biasAdd") - relu = nn.relu(t, "relu") - idty = aops.identity(relu, "ID") - v = nn_ops.max_pool( - idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") - aops.squeeze(v, name="output") - return g.as_graph_def() - - def run_graph(self, gdef, dumm_inp): - """Run given graphdef once.""" - ops.reset_default_graph() + self._input = np.random.random_sample(INPUT_DIMS) + + def _GetConfigProto(self, + use_optimizer, + precision_mode=None, + is_dynamic_op=None): + if use_optimizer: + rewriter_cfg = rewriter_config_pb2.RewriterConfig() + rewriter_cfg.optimizers.extend(["constfold", "layout"]) + custom_op = rewriter_cfg.custom_optimizers.add() + custom_op.name = "TensorRTOptimizer" + custom_op.parameter_map["minimum_segment_size"].i = 3 + custom_op.parameter_map["max_batch_size"].i = self._input.shape[0] + custom_op.parameter_map["is_dynamic_op"].b = is_dynamic_op + custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25 + custom_op.parameter_map["precision_mode"].s = to_bytes(precision_mode) + graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg) + else: + graph_options = config_pb2.GraphOptions() + + gpu_options = config_pb2.GPUOptions() + if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: + gpu_options.per_process_gpu_memory_fraction = 0.50 + + config = config_pb2.ConfigProto( + gpu_options=gpu_options, graph_options=graph_options) + return config + + def _RunGraph(self, graph_key, gdef, input_data, config, num_runs=2): + """Run given graphdef multiple times.""" g = ops.Graph() with g.as_default(): inp, out = importer.import_graph_def( - graph_def=gdef, return_elements=["input", "output"]) + graph_def=gdef, return_elements=[INPUT_NAME, OUTPUT_NAME], name="") inp = inp.outputs[0] out = out.outputs[0] with self.test_session( - graph=g, config=self._config, use_gpu=True, force_gpu=True) as sess: - val = sess.run(out, {inp: dumm_inp}) + graph=g, config=config, use_gpu=True, force_gpu=True) as sess: + val = None + # Defaults to 2 runs to verify result across multiple runs is same. + for _ in range(num_runs): + new_val = sess.run(out, {inp: input_data}) + self.assertEquals(TEST_GRAPHS[graph_key].expected_output_dims, + new_val.shape) + if val is not None: + self.assertAllEqual(new_val, val) + val = new_val return val # Use real data that is representative of the inference dataset # for calibration. For this test script it is random data. - def run_calibration(self, gdef, dumm_inp): - """Run given calibration graph multiple times.""" - ops.reset_default_graph() - g = ops.Graph() - with g.as_default(): - inp, out = importer.import_graph_def( - graph_def=gdef, return_elements=["input", "output"]) - inp = inp.outputs[0] - out = out.outputs[0] - # run over real calibration data here, we are mimicking a calibration - # set of 30 different batches. Use as much calibration data as you want - with self.test_session( - graph=g, config=self._config, use_gpu=True, force_gpu=True) as sess: - for _ in range(30): - val = sess.run(out, {inp: dumm_inp}) - return val + def _RunCalibration(self, graph_key, gdef, input_data, config): + """Run calibration on given graph.""" + return self._RunGraph(graph_key, gdef, input_data, config, 30) - def get_trt_graph(self, mode): + def _GetTrtGraph(self, gdef, precision_mode, is_dynamic_op): """Return trt converted graph.""" - if mode in ["FP32", "FP16", "INT8"]: - return trt.create_inference_graph( - input_graph_def=self._original_graph, - outputs=["output"], - max_batch_size=self._input.shape[0], - max_workspace_size_bytes=1 << 25, - precision_mode=mode, # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2 # minimum number of nodes in an engine - ) - return None - - def testFP32(self): - """Test FP32 conversion. Results should be identical to native case.""" - trt_graph = self.get_trt_graph("FP32") - result = self.run_graph(trt_graph, self._input) - self.assertAllEqual(self._reference, result) - result1 = self.run_graph(trt_graph, self._input) - self.assertAllEqual(result1, result) - - def testFP16(self): - """Test FP16 conversion. Results may be different from native case.""" - trt_graph = self.get_trt_graph("FP16") - result = self.run_graph(trt_graph, self._input) - self.assertAllClose(self._reference, result, rtol=1.e-03) - result1 = self.run_graph(trt_graph, self._input) - self.assertAllEqual(result1, result) - - def testINT8(self): - """Test INT8 conversion. Results may be different from native case.""" - calib_graph = self.get_trt_graph("INT8") - result = self.run_calibration(calib_graph, self._input) - self.assertAllEqual(self._reference, result) - int8_graph = trt.calib_graph_to_infer_graph(calib_graph) - result = self.run_graph(int8_graph, self._input) - self.assertAllClose(self._reference, result, rtol=1.e-03) - result1 = self.run_graph(int8_graph, self._input) - self.assertAllEqual(result1, result) + return trt.create_inference_graph( + input_graph_def=gdef, + outputs=[OUTPUT_NAME], + max_batch_size=self._input.shape[0], + max_workspace_size_bytes=1 << 25, + precision_mode=precision_mode, + minimum_segment_size=2, + is_dynamic_op=is_dynamic_op) + + def _VerifyGraphDef(self, + graph_key, + gdef, + precision_mode=None, + is_calibrated=None, + dynamic_engine=None): + num_engines = 0 + for n in gdef.node: + if n.op == "TRTEngineOp": + num_engines += 1 + self.assertNotEqual("", n.attr["serialized_segment"].s) + self.assertNotEqual("", n.attr["segment_funcdef_name"].s) + self.assertEquals(n.attr["precision_mode"].s, precision_mode) + self.assertEquals(n.attr["static_engine"].b, not dynamic_engine) + if precision_mode == MODE_INT8 and is_calibrated: + self.assertNotEqual("", n.attr["calibration_data"].s) + else: + self.assertEquals("", n.attr["calibration_data"].s) + if precision_mode is None: + self.assertEquals(num_engines, 0) + else: + self.assertEquals(num_engines, + TEST_GRAPHS[graph_key].num_expected_engines) + + def _RunTest(self, graph_key, use_optimizer, precision_mode, + dynamic_infer_engine, dynamic_calib_engine): + assert precision_mode in [MODE_FP32, MODE_FP16, MODE_INT8] + input_gdef = TEST_GRAPHS[graph_key].gdef + self._VerifyGraphDef(graph_key, input_gdef) + + # Get reference result without running trt. + config_no_trt = self._GetConfigProto(False) + print("Running original graph w/o trt, config:\n%s" % str(config_no_trt)) + ref_result = self._RunGraph(graph_key, input_gdef, self._input, + config_no_trt) + + # Run calibration if necessary. + if precision_mode == MODE_INT8: + + calib_config = self._GetConfigProto(use_optimizer, precision_mode, + dynamic_calib_engine) + print("Running calibration graph, config:\n%s" % str(calib_config)) + if use_optimizer: + self.assertTrue(False) + # TODO(aaroey): uncomment this and get infer_gdef when this mode is + # supported. + # result = self._RunCalibration(graph_key, input_gdef, self._input, + # calib_config) + else: + calib_gdef = self._GetTrtGraph(input_gdef, precision_mode, + dynamic_calib_engine) + self._VerifyGraphDef(graph_key, calib_gdef, precision_mode, False, + dynamic_calib_engine) + result = self._RunCalibration(graph_key, calib_gdef, self._input, + calib_config) + infer_gdef = trt.calib_graph_to_infer_graph(calib_gdef) + self._VerifyGraphDef(graph_key, infer_gdef, precision_mode, True, + dynamic_calib_engine) + self.assertAllClose(ref_result, result, rtol=1.e-03) + else: + infer_gdef = input_gdef + + # Run inference. + infer_config = self._GetConfigProto(use_optimizer, precision_mode, + dynamic_infer_engine) + print("Running final inference graph, config:\n%s" % str(infer_config)) + if use_optimizer: + result = self._RunGraph(graph_key, infer_gdef, self._input, infer_config) + else: + trt_infer_gdef = self._GetTrtGraph(infer_gdef, precision_mode, + dynamic_infer_engine) + self._VerifyGraphDef(graph_key, trt_infer_gdef, precision_mode, True, + dynamic_infer_engine) + result = self._RunGraph(graph_key, trt_infer_gdef, self._input, + infer_config) + self.assertAllClose(ref_result, result, rtol=1.e-03) + + def testIdempotence(self): + # Test that applying tensorrt optimizer or offline conversion tools multiple + # times to the same graph will result in same graph. + # TODO(aaroey): implement this. + pass + + +def GetTests(): + + def _GetTest(g, u, p, i, c): + + def _Test(self): + print("Running test with parameters: graph_key=%s, use_optimizer=%s, " + "precision_mode=%s, dynamic_infer_engine=%s, " + "dynamic_calib_engine=%s" % (g, u, p, i, c)) + self._RunTest(g, u, p, i, c) + + return _Test + + use_optimizer_options = [False, True] + precision_mode_options = [MODE_FP32, MODE_FP16, MODE_INT8] + dynamic_infer_engine_options = [False, True] + dynamic_calib_engine_options = [False, True] + for (graph_key, use_optimizer, precision_mode, + dynamic_infer_engine, dynamic_calib_engine) in itertools.product( + TEST_GRAPHS, use_optimizer_options, precision_mode_options, + dynamic_infer_engine_options, dynamic_calib_engine_options): + if precision_mode == MODE_INT8: + if not dynamic_calib_engine and dynamic_infer_engine: + # TODO(aaroey): test this case, the conversion from static calibration + # engine to dynamic inference engine should be a noop. + continue + if use_optimizer: + # TODO(aaroey): if use_optimizer is True we need to get the inference + # graphdef using custom python wrapper class, which is not currently + # supported yet. + continue + if not dynamic_calib_engine: + # TODO(aaroey): construction of static calibration engine is not + # supported yet. + continue + if dynamic_calib_engine and not dynamic_infer_engine: + # TODO(aaroey): construction of static inference engine using dynamic + # calibration engine is not supported yet. + continue + else: # In non int8 mode. + if dynamic_calib_engine: + # dynamic_calib_engine doesn't affect non-int8 modes, so just let + # related tests run once on dynamic_calib_engine=False. + continue + yield _GetTest(graph_key, use_optimizer, precision_mode, + dynamic_infer_engine, dynamic_calib_engine) if __name__ == "__main__": - googletest.main() + for index, t in enumerate(GetTests()): + setattr(TfTrtIntegrationTest, "testTfTRT_" + str(index), t) + test.main() diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i index 46480e99a113afb34702b0ecd71468d4bdc83f98..d51a0b59e22cb063b380808f5887538e0294daff 100644 --- a/tensorflow/contrib/tensorrt/trt_conversion.i +++ b/tensorflow/contrib/tensorrt/trt_conversion.i @@ -48,12 +48,53 @@ PyObject* pair_helper(std::pair* in) { } return tuple; } + +struct version_struct{ + int vmajor; + int vminor; + int vpatch; +}; + +PyObject* version_helper(version_struct* in) { + PyObject *tuple(nullptr); + tuple = Py_BuildValue("(iii)", in->vmajor, in->vminor, in->vpatch); + if (!tuple) { + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_TypeError, + "Tuple creation from version structure failed!"); + } + return NULL; + } + return tuple; +} +/* Define converters for vector */ +template<> +bool _PyObjAs(PyObject *pyobj, int* dest) { + *dest = PyLong_AsLong(pyobj); + return true; +} + +template<> +PyObject *_PyObjFrom(const int& src) { + return PyLong_FromLong(src); +} + %} + +_LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong); + %typemap(out) std::pair { PyObject *tuple = pair_helper(&$1); if (!tuple) SWIG_fail; $result = tuple; } + +%typemap(out) version_struct { + PyObject *tuple = version_helper(&$1); + if (!tuple) SWIG_fail; + $result = tuple; +} + %{ #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -65,6 +106,8 @@ PyObject* pair_helper(std::pair* in) { %unignore tensorflow; %unignore trt_convert; %unignore calib_convert; +%unignore get_linked_tensorrt_version; +%unignore get_loaded_tensorrt_version; %{ @@ -74,7 +117,10 @@ std::pair trt_convert( size_t max_batch_size, size_t max_workspace_size_bytes, int precision_mode, - int minimum_segment_size + int minimum_segment_size, + bool is_dyn_op, + int max_cached_engines, + std::vector cached_engine_batches // Unfortunately we can't use TF_Status here since it // is in c/c_api and brings in a lot of other libraries // which in turn declare ops. These ops are included @@ -102,11 +148,12 @@ std::pair trt_convert( out_status = "InvalidArgument;Size of the output_names vector is 0"; return std::pair{out_status, ""}; } - tensorflow::GraphDef outGraph; + tensorflow::GraphDef out_graph; tensorflow::Status conversion_status = tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT( graph_def, output_names, max_batch_size, max_workspace_size_bytes, - &outGraph, precision_mode, minimum_segment_size); + &out_graph, precision_mode, minimum_segment_size, + is_dyn_op, max_cached_engines, cached_engine_batches); if (!conversion_status.ok()) { auto retCode = (int)conversion_status.code(); char buff[2000]; @@ -116,7 +163,7 @@ std::pair trt_convert( return std::pair{out_status, ""}; } string result; - if (!outGraph.SerializeToString(&result)) { + if (!out_graph.SerializeToString(&result)) { out_status = "InvalidArgument;Couldn't serialize output as a GraphDef"; return std::pair{out_status, ""}; } @@ -128,7 +175,8 @@ std::pair trt_convert( #endif // GOOGLE_CUDA && GOOGLE_TENSORRT } -std::pair calib_convert(string graph_def_string // const tensorflow::GraphDef& +std::pair calib_convert( + string graph_def_string, bool is_dyn_op // unfortunately we can't use TF_Status here since it // is in c/c_api and brings in a lot of other libraries // which in turn declare ops. These ops are included @@ -147,11 +195,11 @@ std::pair calib_convert(string graph_def_string // const tenso out_status = "InvalidArgument;Couldn't interpret input as a GraphDef"; return std::pair{out_status, ""}; } - - tensorflow::GraphDef outGraph; + graph_def_string.resize(0); + tensorflow::GraphDef out_graph; tensorflow::Status conversion_status = - tensorflow::tensorrt::convert::ConvertCalibGraphToInferGraph(graph_def, - &outGraph); + tensorflow::tensorrt::convert::ConvertCalibGraphToInferGraph( + graph_def, &out_graph, is_dyn_op); if (!conversion_status.ok()) { auto retCode = (int)conversion_status.code(); char buff[2000]; @@ -161,7 +209,7 @@ std::pair calib_convert(string graph_def_string // const tenso return std::pair{out_status, ""}; } string result; - if (!outGraph.SerializeToString(&result)) { + if (!out_graph.SerializeToString(&result)) { out_status = "InvalidArgument;Couldn't serialize output as a GraphDef"; return std::pair{out_status, ""}; } @@ -172,15 +220,39 @@ std::pair calib_convert(string graph_def_string // const tenso return std::pair{"9;TensorRT is not enabled!", ""}; #endif // GOOGLE_CUDA && GOOGLE_TENSORRT } + +version_struct get_linked_tensorrt_version(){ + // Return the version at the link time. + const auto &lv = tensorflow::tensorrt::convert::GetLinkedTensorRTVersion(); + version_struct s; + s.vmajor = lv[0]; + s.vminor = lv[1]; + s.vpatch = lv[2]; + return s; +} +version_struct get_loaded_tensorrt_version(){ + // Return the version from the loaded library. + const auto &lv = tensorflow::tensorrt::convert::GetLoadedTensorRTVersion(); + version_struct s; + s.vmajor = lv[0]; + s.vminor = lv[1]; + s.vpatch = lv[2]; + return s; +} + %} -std::pair calib_convert(string graph_def_string); +std::pair calib_convert(string graph_def_string, bool is_dyn_op); std::pair trt_convert(string graph_def_string, std::vector output_names, size_t max_batch_size, size_t max_workspace_size_bytes, - int precision_mode, int minimum_segment_size); - + int precision_mode, int minimum_segment_size, + bool is_dyn_op, + int max_cached_engines, + std::vector cached_engine_batches); +version_struct get_linked_tensorrt_version(); +version_struct get_loaded_tensorrt_version(); %unignoreall diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py index a28a5872b850b51630240bdeb3ff22f372613523..f236329fdb038ba5ab432c6b97f44bda7ccfe815 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head.py @@ -132,7 +132,8 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce loss=model_outputs.loss, mode=mode, eval_metric_ops=metrics, - predictions={}) + # needed for custom metrics. + predictions=model_outputs.predictions) def _predict_ops(self, features): """Add ops for prediction to the graph.""" @@ -210,12 +211,12 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce def create_estimator_spec(self, features, mode, labels=None): """Performs basic error checking and returns an EstimatorSpec.""" with ops.name_scope(self._name, "head"): - if labels: + if labels is not None and labels != {}: # for better error messages. raise ValueError( - "The model received a `labels` dictionary, which is " - "not supported. Pass '{}' and '{}' as " - "features.".format(feature_keys.TrainEvalFeatures.TIMES, - feature_keys.TrainEvalFeatures.VALUES)) + "The model received a `labels`, which is not supported. " + "Pass '{}' and '{}' as features.".format( + feature_keys.TrainEvalFeatures.TIMES, + feature_keys.TrainEvalFeatures.VALUES)) del labels features = { name: self._convert_feature_to_tensor(name=name, value=value) diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py index c606db76a668235ab6a837159b9dec072b5fd801..ed8f29c321719e552c25f4d2183fdf4eb282e4b7 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy import six +from tensorflow.contrib.estimator.python.estimator import extenders from tensorflow.contrib.timeseries.examples import lstm as lstm_example from tensorflow.contrib.timeseries.python.timeseries import estimators as ts_estimators from tensorflow.contrib.timeseries.python.timeseries import feature_keys @@ -35,6 +36,7 @@ from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics from tensorflow.python.ops import variables @@ -53,9 +55,12 @@ class HeadTest(test.TestCase): model_fn = _stub_model_fn() for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL, estimator_lib.ModeKeys.PREDICT]: - with self.assertRaisesRegexp(ValueError, "labels"): + with self.assertRaisesRegexp(ValueError, "received a `labels`"): model_fn(features={}, labels={"a": "b"}, mode=mode) + with self.assertRaisesRegexp(ValueError, "received a `labels`"): + model_fn(features={}, labels=array_ops.zeros([]), mode=mode) + def test_unknown_mode(self): model_fn = _stub_model_fn() with self.assertRaisesRegexp(ValueError, "Unknown mode 'Not a mode'"): @@ -128,6 +133,44 @@ class EvaluationMetricsTests(test.TestCase): coordinator.request_stop() coordinator.join() + def test_custom_metrics(self): + """Tests that the custom metrics can be applied to the estimator.""" + model_dir = self.get_temp_dir() + estimator = ts_estimators.TimeSeriesRegressor( + model=lstm_example._LSTMModel(num_features=1, num_units=4), + optimizer=adam.AdamOptimizer(0.001), + config=estimator_lib.RunConfig(tf_random_seed=4), + model_dir=model_dir) + + def input_fn(): + return { + feature_keys.TrainEvalFeatures.TIMES: [[1, 2, 3], [7, 8, 9]], + feature_keys.TrainEvalFeatures.VALUES: + numpy.array([[[0.], [1.], [0.]], [[2.], [3.], [2.]]]) + } + + def metrics_fn(predictions, features): + # checking that the inputs are properly passed. + predict = predictions["mean"] + target = features[feature_keys.TrainEvalFeatures.VALUES][:, -1, 0] + return { + "plain_boring_metric386": + (math_ops.reduce_mean(math_ops.abs(predict - target)), + control_flow_ops.no_op()), + "fun_metric101": (math_ops.reduce_sum(predict + target), + control_flow_ops.no_op()), + } + + # Evaluation without training is enough for testing custom metrics. + estimator = extenders.add_metrics(estimator, metrics_fn) + evaluation = estimator.evaluate(input_fn, steps=1) + self.assertIn("plain_boring_metric386", evaluation) + self.assertIn("fun_metric101", evaluation) + # The values are deterministic because of fixed tf_random_seed. + # However if they become flaky, remove such exacts comparisons. + self.assertAllClose(evaluation["plain_boring_metric386"], 1.130380) + self.assertAllClose(evaluation["fun_metric101"], 10.435442) + class _StubModel(object): num_features = 3 diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index f84ff1bfe9b014733205a8e51b43f79c63b227cb..16696793bc2dab977a3dbbfa338e33e5771d0699 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -181,6 +181,7 @@ py_library( ":datasets", ":profiler", ":tpu_py", + "//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py", "//tensorflow/contrib/tpu/proto:compilation_result_proto_py", "//tensorflow/contrib/tpu/proto:topology_proto_py", "//tensorflow/core:protos_all_py", diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc index f632c953c85fcc335410c10db785265af9d8ddf3..15a2bb17a93212afe9ce5604a28d9dba5825f7d4 100644 --- a/tensorflow/contrib/tpu/ops/replication_ops.cc +++ b/tensorflow/contrib/tpu/ops/replication_ops.cc @@ -53,10 +53,10 @@ REGISTER_OP("TPUReplicatedInput") nullptr; for (int i = c->num_inputs() - 1; i >= 0; --i) { if (shapes_and_types) { - if (!c->MergeInputHandleShapesAndTypes(i, *shapes_and_types)) { - return errors::InvalidArgument( - "Incompatible resource shapes for replicated TPU input."); - } + // The return value of MergeInputHandleShapesAndTypes indicates + // the shape was refined, not that there was an error. + // TODO(phawkins): there seems to be no way to discover errors. + (void)c->MergeInputHandleShapesAndTypes(i, *shapes_and_types); } else { shapes_and_types = c->input_handle_shapes_and_types(i); } diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py index 7f1d25732e21b5dea4e605f6caa141ca9d3d02c6..7a5d01cca42351f6d4d8b41d43756560ce7874d3 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py @@ -17,12 +17,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from absl import flags - import os import subprocess import sys - +from absl import flags +from distutils.version import LooseVersion import tensorflow as tf # Cloud TPU Cluster Resolvers @@ -35,9 +34,9 @@ flags.DEFINE_string( None, help='GCE zone where the Cloud TPU is located in. If not specified, we ' 'will attempt to automatically detect the GCE project from metadata.') -flags.DEFINE_string('tpu', None, - 'Name of the Cloud TPU for Cluster Resolvers. You must ' - 'specify either this flag or --service_addr.') +flags.DEFINE_string( + 'tpu', None, 'Name of the Cloud TPU for Cluster Resolvers. You must ' + 'specify either this flag or --service_addr.') # Tool specific parameters flags.DEFINE_string( @@ -48,13 +47,13 @@ flags.DEFINE_string( ' e.g. 10.0.1.2, 10.0.1.3. You can specify this flag with --tpu or ' '--service_addr to profile a subset of tpu nodes. You can also use only' '--tpu and leave this flag unspecified to profile all the tpus.') -flags.DEFINE_string('logdir', None, - 'Path of TensorBoard log directory e.g. /tmp/tb_log, ' - 'gs://tb_bucket') +flags.DEFINE_string( + 'logdir', None, 'Path of TensorBoard log directory e.g. /tmp/tb_log, ' + 'gs://tb_bucket') flags.DEFINE_integer('duration_ms', 2000, 'Duration of tracing in ms.') -flags.DEFINE_integer('num_tracing_attempts', 3, - 'Automatically retry N times when no trace ' - 'event is collected.') +flags.DEFINE_integer( + 'num_tracing_attempts', 3, 'Automatically retry N times when no trace ' + 'event is collected.') flags.DEFINE_boolean('include_dataset_ops', True, 'Set to false to profile longer TPU ' 'device traces.') @@ -63,18 +62,24 @@ FLAGS = flags.FLAGS EXECUTABLE = 'data/capture_tpu_profile' JOB_NAME = 'worker' + def get_workers_list(cluster_resolver): cluster_spec = cluster_resolver.cluster_spec() task_indices = cluster_spec.task_indices(JOB_NAME) - workers_list = [cluster_spec.task_address(JOB_NAME, i).split(':')[0] - for i in task_indices] + workers_list = [ + cluster_spec.task_address(JOB_NAME, i).split(':')[0] for i in task_indices + ] return ','.join(workers_list) + def run_main(): tf.app.run(main) + def main(unused_argv=None): tf.logging.set_verbosity(tf.logging.INFO) + tf_version = tf.__version__ + print('TensorFlow version %s detected' % tf_version) if FLAGS.service_addr is None and FLAGS.tpu is None: sys.exit('You must specify either --service_addr or --tpu.') @@ -88,17 +93,19 @@ def main(unused_argv=None): else: tpu_cluster_resolver = ( tf.contrib.cluster_resolver.TPUClusterResolver( - [FLAGS.tpu], - zone=FLAGS.tpu_zone, - project=FLAGS.gcp_project)) + [FLAGS.tpu], zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)) service_addr = tpu_cluster_resolver.get_master() service_addr = service_addr.replace('grpc://', '').replace(':8470', ':8466') - workers_list = "" - if FLAGS.workers_list is not None: - workers_list = FLAGS.workers_list - elif tpu_cluster_resolver is not None: - workers_list = get_workers_list(tpu_cluster_resolver) + workers_list = '' + if LooseVersion(tf_version) < LooseVersion('1.9'): + tf.logging.warn('Attempt to profile with legacy support under TensorFlow ' + 'version %s' % tf_version) + else: + if FLAGS.workers_list is not None: + workers_list = FLAGS.workers_list + elif tpu_cluster_resolver is not None: + workers_list = get_workers_list(tpu_cluster_resolver) if not FLAGS.logdir: sys.exit('logdir must be provided.') diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py index f97a972f01a3ba5582df3675439aa962886f796e..19f088f8b862ce7b114490151f2b6a8c260b8580 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py @@ -20,7 +20,7 @@ from __future__ import print_function from setuptools import setup -_VERSION = '1.7.0' +_VERSION = '1.9.0' CONSOLE_SCRIPTS = [ 'capture_tpu_profile=cloud_tpu_profiler.main:run_main', diff --git a/tensorflow/contrib/tpu/profiler/version.h b/tensorflow/contrib/tpu/profiler/version.h index bd9ba6697edd9ef14dd3af0d2c9b77df9ec6917a..1bf49966d12db83f1e6904f8c00453bba278847c 100644 --- a/tensorflow/contrib/tpu/profiler/version.h +++ b/tensorflow/contrib/tpu/profiler/version.h @@ -16,6 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ #define TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ -#define TPU_PROFILER_VERSION "1.7.0" +#define TPU_PROFILER_VERSION "1.9.0" #endif // TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ diff --git a/tensorflow/contrib/tpu/proto/BUILD b/tensorflow/contrib/tpu/proto/BUILD index 7ecb36852c53bb74d70ed0f8c70ca1ce860a037a..26016f47dfb36990fd73267c70619878ac3450e5 100644 --- a/tensorflow/contrib/tpu/proto/BUILD +++ b/tensorflow/contrib/tpu/proto/BUILD @@ -2,7 +2,12 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_additional_all_protos", + "tf_proto_library", + "tf_proto_library_py", +) tf_proto_library( name = "tpu_embedding_config_proto", @@ -22,12 +27,14 @@ tf_proto_library( visibility = ["//visibility:public"], ) -tf_proto_library( +tf_proto_library_py( name = "compilation_result_proto", srcs = [ "compilation_result.proto", ], - cc_api_version = 2, - protodeps = ["//tensorflow/core:protos_all"], + protodeps = tf_additional_all_protos() + [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", + ], visibility = ["//visibility:public"], ) diff --git a/tensorflow/contrib/tpu/proto/compilation_result.proto b/tensorflow/contrib/tpu/proto/compilation_result.proto index cf52897de3d0fefa55e68a6b889ae9af7b45864a..88585a5bd10fc28aa34bb0de72de970e21b2adb2 100644 --- a/tensorflow/contrib/tpu/proto/compilation_result.proto +++ b/tensorflow/contrib/tpu/proto/compilation_result.proto @@ -3,6 +3,7 @@ syntax = "proto3"; option cc_enable_arenas = true; package tensorflow.tpu; +import "tensorflow/compiler/xla/service/hlo.proto"; import "tensorflow/core/lib/core/error_codes.proto"; // Describes the result of a TPU compilation. @@ -10,4 +11,7 @@ message CompilationResultProto { // The error message, if any, returned during compilation. error.Code status_code = 1; string status_error_message = 2; + + // HLO proto. + repeated xla.HloProto hlo_protos = 3; } diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index f1a11fa6548b87d6222a97c72b8db5442c8ef774..754154438235f4c5e9e8db996acc8d843ab18431 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -19,15 +19,16 @@ To use, wrap your model with the `keras_support.tpu_model` function. Example usage: ``` -# Must activate before building TPU models -keras_support.setup_tpu_session(master_address) - image = tf.keras.layers.Input(shape=(28, 28, 3), name='image') c1 = tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3))( image) flattened = tf.keras.layers.Flatten()(c1) logits = tf.keras.layers.Dense(10, activation='softmax')(flattened) model = tf.keras.Model(inputs=[image], outputs=[logits]) -model = keras_support.tpu_model(model) + +strategy = keras_support.TPUDistributionStrategy(num_cores_per_host=8) +model = keras_support.tpu_model(model, + strategy=strategy, + tpu_name_or_address=tpu_name) # Only TF optimizers are currently supported. model.compile(optimizer=tf.train.AdamOptimizer(), ...) @@ -35,9 +36,6 @@ model.compile(optimizer=tf.train.AdamOptimizer(), ...) # `images` and `labels` should be Numpy arrays. Support for tensor input # (e.g. datasets) is planned. model.fit(images, labels) - -# Invoke before shutting down -keras_support.shutdown_tpu_session() ``` """ @@ -48,9 +46,15 @@ from __future__ import division from __future__ import print_function import collections +import contextlib import re +import sys import time +import numpy as np + +from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver +from tensorflow.contrib.distribute.python import tpu_strategy from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result from tensorflow.contrib.tpu.python.ops import tpu_ops @@ -62,14 +66,17 @@ from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec from tensorflow.python.keras import backend as K -from tensorflow.python.keras import layers from tensorflow.python.keras import models from tensorflow.python.keras import optimizers as keras_optimizers +from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.layers import embeddings from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging +TPUDistributionStrategy = tpu_strategy.TPUStrategy # pylint: disable=invalid-name + class TPUEmbedding(embeddings.Embedding): """TPU compatible embedding layer. @@ -93,10 +100,9 @@ class TPUEmbedding(embeddings.Embedding): class TPUModelOp( - collections.namedtuple( - 'TPUModelOp', - ['compile_op', 'execute_op', 'infeed_tensors', 'infeed_op', - 'outfeed_op'])): + collections.namedtuple('TPUModelOp', [ + 'compile_op', 'execute_op', 'infeed_tensors', 'infeed_op', 'outfeed_op' + ])): pass @@ -105,13 +111,69 @@ def _valid_name(tensor_name): return re.sub('[^a-zA-Z0-9_-]+', '', tensor_name) -def _replicated_optimizer(opt, num_replicas): +def _replicated_optimizer(opt): """Wrap the optimizer `opt` with CrossShardOptimizer if applicable.""" - if num_replicas == 1: - return opt return keras_optimizers.TFOptimizer( - optimizer=tpu_optimizer.CrossShardOptimizer(opt.optimizer) - ) + optimizer=tpu_optimizer.CrossShardOptimizer(opt.optimizer)) + + +class TPURewriteContext(object): + """Prepare the environment for a Keras model during `tpu.rewrite`. + + This overrides the default placeholder behaviour to instead refer to a preset + input mapping. Placeholders are unsupported in TPU compiled code, and must + be replaced with explicit inputs or values from the infeed queue. + + Instead of explicitly threading inputs all the way through the Keras codebase, + we override the behavior of the placeholder while compiling and inject the + Tensors from the infeed in place of the placeholder. + + Similarly, as we compile a new sub-graph for each unique shape and execution + mode, we need to override the behavior of an embedded `name_scope` call in + the base Keras layer code. This allows us to re-use the same weights across + many compiles and share a single session/graph. + """ + + def __init__(self, input_map): + self._input_map = input_map + self._default_placeholder = None + self._default_name_scope = None + + def __enter__(self): + + def _placeholder(dtype, shape=None, name=None): # pylint: disable=unused-argument + logging.info('Remapping placeholder for %s', name) + if name in self._input_map: + return self._input_map[name] + else: + logging.info('Default: %s', name) + return self._default_placeholder(dtype, shape, name) + + def _name_scope(name, default_name=None, values=None): + caller_frame = sys._getframe().f_back + caller_obj = caller_frame.f_locals.get('self') + if (caller_obj is not None and + isinstance(caller_obj, base_layer.Layer) and name is not None): + logging.info('Intercepted name_scope: %s', caller_obj) + return variable_scope.variable_scope( + name, default_name, values, reuse=variable_scope.AUTO_REUSE) + + return self._default_name_scope(name, default_name, values) + + self._default_placeholder = array_ops.placeholder + self._default_name_scope = ops.name_scope + self._default_make_variable = base_layer.make_variable + + array_ops.placeholder = _placeholder + ops.name_scope = _name_scope + base_layer.make_variable = variable_scope.get_variable + logging.info('Overriding default placeholder.') + return + + def __exit__(self, exc_type, exc_val, exc_tb): + array_ops.placeholder = self._default_placeholder + ops.name_scope = self._default_name_scope + base_layer.make_variable = self._default_make_variable class TPUFunction(object): @@ -126,19 +188,18 @@ class TPUFunction(object): instead of being injected as `feed_dict` items or fetches. """ - def __init__(self, model, execution_mode, num_replicas=1): + def __init__(self, model, execution_mode, strategy): self.model = model self.execution_mode = execution_mode + self._strategy = strategy self._compilation_cache = {} - self.num_replicas = num_replicas + self._cloned_model = None def _specialize_model(self, input_specs): """Specialize `self.model` (a Keras model) for the given input shapes.""" # Re-create our input and output layers inside our subgraph. They will be # attached to the true computation when we clone our model in `tpu_fn`. - K.set_learning_phase( - self.execution_mode == model_fn_lib.ModeKeys.TRAIN - ) + K.set_learning_phase(self.execution_mode == model_fn_lib.ModeKeys.TRAIN) # functools.partial and callable objects are not supported by tpu.rewrite def _model_fn(): @@ -164,23 +225,22 @@ class TPUFunction(object): infeed_tensors)) tpu_targets = [] - tpu_inputs = [] + tpu_input_map = {} # Sort infeed outputs into inputs and labels for calling our Keras model. for tensor, layer in zip(infeed_tensors, infeed_layers): if layer in self.model._input_layers: - tpu_inputs.append(layers.Input(name=layer.name, tensor=tensor)) + tpu_input_map[layer.name] = tensor if layer in self.model._output_layers: tpu_targets.append(tensor) - # Call our model with our infeed inputs (re-using the weights). - model_outputs = self.model(tpu_inputs) - child_model = models.Model(inputs=tpu_inputs, outputs=model_outputs) + # Clone our CPU model, running within the TPU device context. + with TPURewriteContext(tpu_input_map): + self._cloned_model = models.clone_model(self.model) if is_training or is_test: - child_model.compile( - optimizer=_replicated_optimizer(self.model.optimizer, - self.num_replicas), + self._cloned_model.compile( + optimizer=_replicated_optimizer(self.model.optimizer), loss=self.model.loss, loss_weights=self.model.loss_weights, metrics=self.model.metrics, @@ -190,37 +250,37 @@ class TPUFunction(object): # Compute our outfeed depending on the execution mode if is_training: - child_model._make_train_function() + self._cloned_model._make_train_function() self._outfeed_spec = [ tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name) - for tensor in child_model.train_function.outputs + for tensor in self._cloned_model.train_function.outputs ] return [ - child_model.train_function.updates_op, + self._cloned_model.train_function.updates_op, tpu_ops.outfeed_enqueue_tuple( - child_model.train_function.outputs, + self._cloned_model.train_function.outputs, name='outfeed-enqueue-train') ] elif is_test: - child_model._make_test_function() + self._cloned_model._make_test_function() self._outfeed_spec = [ tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name) - for tensor in child_model.test_function.outputs + for tensor in self._cloned_model.test_function.outputs ] return [ tpu_ops.outfeed_enqueue_tuple( - child_model.test_function.outputs, + self._cloned_model.test_function.outputs, name='outfeed-enqueue-test') ] elif is_predict: - child_model._make_predict_function() + self._cloned_model._make_predict_function() self._outfeed_spec = [ tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name) - for tensor in child_model.predict_function.outputs + for tensor in self._cloned_model.predict_function.outputs ] return [ tpu_ops.outfeed_enqueue_tuple( - child_model.predict_function.outputs, + self._cloned_model.predict_function.outputs, name='outfeed-enqueue-predict', ) ] @@ -235,7 +295,7 @@ class TPUFunction(object): # `execute op` replicates `_model_fn` `num_replicas` times, with each shard # running on a different logical core. compile_op, execute_op = tpu.split_compile_and_replicate( - _model_fn, inputs=[[]] * self.num_replicas) + _model_fn, inputs=[[]] * self._strategy.num_towers) # Generate CPU side operations to enqueue features/labels and dequeue # outputs from the model call. @@ -243,7 +303,7 @@ class TPUFunction(object): outfeed_op = [] shard_infeed_tensors = [] - for shard_id in range(self.num_replicas): + for shard_id in range(self._strategy.num_towers): with ops.device('/device:TPU:%d' % shard_id): infeed_tensors = [] for spec in input_specs: @@ -254,32 +314,35 @@ class TPUFunction(object): name='infeed-enqueue-%s-%d' % (spec.name, shard_id))) shard_infeed_tensors.append(infeed_tensors) - infeed_op.append(tpu_ops.infeed_enqueue_tuple( - infeed_tensors, [spec.shape for spec in input_specs], - name='infeed-enqueue-%s-%d' % (self.execution_mode, shard_id))) + infeed_op.append( + tpu_ops.infeed_enqueue_tuple( + infeed_tensors, [spec.shape for spec in input_specs], + name='infeed-enqueue-%s-%d' % (self.execution_mode, shard_id))) - outfeed_op.extend(tpu_ops.outfeed_dequeue_tuple( - dtypes=[spec.dtype for spec in self._outfeed_spec], - shapes=[spec.shape for spec in self._outfeed_spec], - name='outfeed-dequeue-%s-%d' % (self.execution_mode, shard_id))) + outfeed_op.extend( + tpu_ops.outfeed_dequeue_tuple( + dtypes=[spec.dtype for spec in self._outfeed_spec], + shapes=[spec.shape for spec in self._outfeed_spec], + name='outfeed-dequeue-%s-%d' % (self.execution_mode, shard_id))) return TPUModelOp( - compile_op, execute_op, infeed_tensors=shard_infeed_tensors, - infeed_op=infeed_op, outfeed_op=outfeed_op) + compile_op, + execute_op, + infeed_tensors=shard_infeed_tensors, + infeed_op=infeed_op, + outfeed_op=outfeed_op) def _test_model_compiles(self, tpu_model_ops): """Verifies that the given TPUModelOp can be compiled via XLA.""" - session = K.get_session() - logging.info('Started compiling') start_time = time.clock() - result = session.run(tpu_model_ops.compile_op) + result = K.get_session().run(tpu_model_ops.compile_op) proto = tpu_compilation_result.CompilationResultProto() proto.ParseFromString(result) if proto.status_error_message: - raise RuntimeError( - 'Compilation failed: {}'.format(proto.status_error_message)) + raise RuntimeError('Compilation failed: {}'.format( + proto.status_error_message)) end_time = time.clock() logging.info('Finished compiling. Time elapsed: %s secs', @@ -296,17 +359,20 @@ class TPUFunction(object): Returns: List of lists containing the input to feed to each TPU shard. """ - if self.num_replicas == 1: + if self._strategy.num_towers == 1: return [inputs] batch_size = inputs[0].shape[0] - assert batch_size % self.num_replicas == 0, ( - 'batch_size must be divisible by num_replicas') - shard_size = batch_size // self.num_replicas + assert batch_size % self._strategy.num_towers == 0, ( + 'batch_size must be divisible by strategy.num_towers (%s vs %s)' % + (batch_size, self._strategy.num_towers) + ) + shard_size = batch_size // self._strategy.num_towers input_list = [] - for index in range(self.num_replicas): - shard_inputs = [x[index * shard_size:(index + 1) * shard_size] - for x in inputs] + for index in range(self._strategy.num_towers): + shard_inputs = [ + x[index * shard_size:(index + 1) * shard_size] for x in inputs + ] input_list.append(shard_inputs) return input_list @@ -343,12 +409,15 @@ class TPUFunction(object): shape_key = tuple([tuple(spec.shape.as_list()) for spec in input_specs]) if shape_key not in self._compilation_cache: - logging.info('New input shapes; (re-)compiling: mode=%s, %s', - self.execution_mode, input_specs) - new_tpu_model_ops = self._specialize_model(input_specs) - self._compilation_cache[shape_key] = new_tpu_model_ops - self._test_model_compiles(new_tpu_model_ops) - + with self.model.tpu_session(): + logging.info('New input shapes; (re-)compiling: mode=%s, %s', + self.execution_mode, input_specs) + new_tpu_model_ops = self._specialize_model(input_specs) + self._compilation_cache[shape_key] = new_tpu_model_ops + self._test_model_compiles(new_tpu_model_ops) + + # Initialize our TPU weights on the first compile. + self.model._initialize_weights(self._cloned_model) tpu_model_ops = self._compilation_cache[shape_key] infeed_dict = {} @@ -357,58 +426,83 @@ class TPUFunction(object): for tensor, value in zip(infeed_tensors, inputs): infeed_dict[tensor] = value - session = K.get_session() - _, _, outfeed_outputs = session.run([ - tpu_model_ops.infeed_op, tpu_model_ops.execute_op, - tpu_model_ops.outfeed_op - ], infeed_dict) + with self.model.tpu_session() as session: + _, _, outfeed_outputs = session.run([ + tpu_model_ops.infeed_op, tpu_model_ops.execute_op, + tpu_model_ops.outfeed_op + ], infeed_dict) # TODO(xiejw): Decide how to reduce outputs, or just discard all but first. - return outfeed_outputs[:len(outfeed_outputs) // self.num_replicas] - - -@experimental -def setup_tpu_session(master): - """Initializes and returns a Keras/TF session connected the TPU `master`.""" - session = tf_session.Session( - target=master, config=config_pb2.ConfigProto(isolate_session_state=True)) - K.set_session(session) - K.get_session().run(tpu.initialize_system()) - return session - - -@experimental -def shutdown_tpu_session(session=None): - """Shutdown the TPU attached to session. + if self.execution_mode == model_fn_lib.ModeKeys.PREDICT: + outputs = [[]] * len(self._outfeed_spec) + outputs_per_replica = len(self._outfeed_spec) - This should be called to cleanly shut down the TPU system before the client - exits. - - Args: - session: Session to shutdown, or None to use the default session. - - Returns: - - """ - if session is None: - session = K.get_session() + for i in range(self._strategy.num_towers): + output_group = outfeed_outputs[ + i * outputs_per_replica:(i+1) * outputs_per_replica + ] + for j in range(outputs_per_replica): + outputs[j].append(output_group[j]) - session.run(tpu.shutdown_system()) + return [np.concatenate(group) for group in outputs] + else: + return outfeed_outputs[:len(outfeed_outputs) // self._strategy.num_towers] class KerasTPUModel(models.Model): """TPU compatible Keras model wrapper.""" - def __init__(self, inputs, outputs, name, replicas=1): + def __init__(self, cpu_model, tpu_name_or_address, strategy): super(models.Model, self).__init__( # pylint: disable=bad-super-call - inputs=inputs, - outputs=outputs, - name=name, + inputs=cpu_model.inputs, + outputs=cpu_model.outputs, + name=cpu_model.name, ) + self.predict_function = None self.test_function = None self.train_function = None - self.replicas = replicas + self._strategy = strategy + + self._tpu_name_or_address = tpu_name_or_address + self._cpu_model = cpu_model + self._tpu_model = None + self._tpu_weights_initialized = False + self._graph = ops.Graph() + + cluster_resolver = tpu_cluster_resolver.TPUClusterResolver( + tpu_name_or_address) + cluster_spec = cluster_resolver.cluster_spec() + self._session = tf_session.Session( + graph=self._graph, + target=cluster_resolver.master(), + config=config_pb2.ConfigProto(isolate_session_state=True)) + + if cluster_spec: + self._session.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) + + with self._graph.as_default(): + self._session.run(tpu.initialize_system()) + + # If the input CPU model has already been compiled, compile our TPU model + # immediately. + if self._cpu_model.optimizer: + self.compile( + self._cpu_model.optimizer, + self._cpu_model.loss, + self._cpu_model.metrics, + self._cpu_model.loss_weights, + self._cpu_model.sample_weight_mode, + self._cpu_model.weighted_metrics, + self._cpu_model.target_tensors, + ) + + def get_config(self): + return { + 'cpu_model': self._cpu_model, + 'tpu_name_or_address': self._tpu_name_or_address, + 'strategy': self._strategy, + } def compile(self, optimizer, @@ -430,6 +524,11 @@ class KerasTPUModel(models.Model): sample_weight_mode, weighted_metrics, target_tensors, **kwargs) + if not self._cpu_model.optimizer: + self._cpu_model.compile(optimizer, loss, metrics, loss_weights, + sample_weight_mode, weighted_metrics, + target_tensors, **kwargs) + # Keras optimizers are not compatible with TPU rewrite if not isinstance(self.optimizer, keras_optimizers.TFOptimizer): raise ValueError( @@ -437,37 +536,90 @@ class KerasTPUModel(models.Model): def _make_train_function(self): if not self.train_function: - self.train_function = TPUFunction(self, model_fn_lib.ModeKeys.TRAIN, - num_replicas=self.replicas) + self.train_function = TPUFunction( + self, model_fn_lib.ModeKeys.TRAIN, strategy=self._strategy) return self.train_function def _make_test_function(self): if not self.test_function: - self.test_function = TPUFunction(self, model_fn_lib.ModeKeys.EVAL) + self.test_function = TPUFunction( + self, model_fn_lib.ModeKeys.EVAL, strategy=self._strategy) return self.test_function def _make_predict_function(self): if not self.predict_function: - self.predict_function = TPUFunction(self, model_fn_lib.ModeKeys.PREDICT) + self.predict_function = TPUFunction( + self, model_fn_lib.ModeKeys.PREDICT, strategy=self._strategy) return self.predict_function - def cpu_model(self): - cpu_model = models.Model( - inputs=self.inputs, - outputs=self.outputs, - name=self.name, - ) + def _initialize_weights(self, cloned_model): + """Initialize TPU weights. - if self.optimizer: - cpu_model.compile( - optimizer=self.optimizer, - loss=self.loss, - metrics=self.metrics, - loss_weights=self.loss_weights, - ) + This is called on the first compile of the TPU model (first call to + fit/predict/evaluate). - return cpu_model + Args: + cloned_model: `keras.Model`, TPU model to initialize. + """ + if self._tpu_weights_initialized: + return + + self._tpu_model = cloned_model + self._tpu_weights_initialized = True + + weights = self._cpu_model.get_weights() + with self.tpu_session(): + logging.info('Setting weights on TPU model.') + cloned_model.set_weights(weights) + + def sync_to_cpu(self): + """Copy weights from the CPU, returning a synchronized CPU model.""" + if self._tpu_weights_initialized: + with self.tpu_session(): + logging.info('Copying TPU weights to the CPU') + tpu_weights = self._tpu_model.get_weights() + + self._cpu_model.set_weights(tpu_weights) + + return self._cpu_model + + def get_weights(self): + return self.sync_to_cpu().get_weights() + + def save_weights(self, *args, **kw): + return self.sync_to_cpu().save_weights(*args, **kw) + + def save(self, *args, **kw): + return self.sync_to_cpu().save(*args, **kw) + + def set_weights(self, weights): + # We may not have a TPU model available if we haven't run fit/predict, so + # we can't directly set the TPU weights here. + # Instead, reset CPU model weights and force TPU re-initialization at the + # next call. + self._cpu_model.set_weights(weights) + self._tpu_weights_initialized = False + + @contextlib.contextmanager + def tpu_session(self): + """Yields a TPU session and sets it as the default Keras session.""" + with self._graph.as_default(): + default_session = K.get_session() + # N.B. We have to call `K.set_session()` AND set our session as the + # TF default. `K.get_session()` surprisingly does not return the value + # supplied by K.set_session otherwise. + K.set_session(self._session) + with self._session.as_default(): + yield self._session + K.set_session(default_session) + + def shutdown(self): + logging.info('Shutting down TPU session.') + with self.tpu_session() as session: + session.run(tpu.shutdown_system()) + + self._session.close() def _validate_shapes(model): @@ -504,26 +656,8 @@ Output shape: %(output_shape)s @experimental -def tpu_model(model, replicas=None): - """Runs a model on TPU(s). - - Usage: - ``` - a = Input(shape=(32,)) - b = Dense(32)(a) - model = Model(inputs=a, outputs=b) - - model = keras_support.tpu_model(model) - model.compile( - optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0), - ...) - ``` - - If `replicas` is set, replicates the model computation on all TPU cores. The - model computation is replicated `num_replicas` times; each shard will run on a - different TPU core. - - Limitation: Currently, replication is only supported for training. +def tpu_model(model, tpu_name_or_address=None, strategy=None): + """Copy `model` along with weights to the TPU. Returns a TPU model. Usage: ``` @@ -531,17 +665,24 @@ def tpu_model(model, replicas=None): b = Dense(32)(a) model = Model(inputs=a, outputs=b) - model = keras_support.tpu_model(model, replicas=2) + # If `num_cores_per_host` is greater than one, batch parallelism will be used + # to run on multiple TPU cores. + strategy = keras_support.TPUDistributionStrategy(num_cores_per_host=8) + model = keras_support.tpu_model(model, strategy) model.compile( optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0), ...) + model.shutdown() ``` Args: model: A `KerasTPUModel`. - replicas: (Optional) Int, number of TPU cores which to create model - replicas. If `None`, the model runs on single core only, i.e., no - replication. + tpu_name_or_address: A string that is either the name of the Cloud TPU, + the grpc address of the Cloud TPU, or (Googlers only) the BNS name of the + Cloud TPU. If tpu_name_or_address is None, the TPUClusterResolver will + examine the environment to determine a potential Cloud TPU to use. + strategy: `TPUDistributionStrategy`. The strategy to use for replicating + model across multiple TPU cores. Returns: A new `KerasTPUModel` instance. @@ -550,7 +691,9 @@ def tpu_model(model, replicas=None): # TODO(xiejw): Validate TPU model. TPUModel only? # TODO(xiejw): Validate replicas. Full or 1. Shall we allow subset? # TODO(xiejw): Adds reduction option. - replicas = 1 if replicas is None else replicas + if strategy is None: + strategy = TPUDistributionStrategy(num_cores_per_host=1) return KerasTPUModel( - inputs=model.inputs, outputs=model.outputs, name=model.name, - replicas=replicas) + cpu_model=model, + tpu_name_or_address=tpu_name_or_address, + strategy=strategy) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index dc473c5846aafc5a92756dfb8259f7f8dc14b98d..6a64893d9abcd64360554ab00502cdf360b820b6 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -227,19 +227,26 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): class FakeOp(object): """A helper class to determine the current device. - Supports only the device set/get methods needed to run the + Supports only the type and device set/get methods needed to run the graph's _apply_device_function method. """ def __init__(self): self._device = "" + @property + def type(self): + return "FakeOp" + @property def device(self): return self._device def _set_device(self, device): - self._device = device.to_string() + if isinstance(device, pydev.DeviceSpec): + self._device = device.to_string() + else: + self._device = device if self._outside_compilation_cluster: raise NotImplementedError("Cannot nest outside_compilation clusters") diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index 5b9aeaa8797b92b4cc596744812f440607054dce..aec59f3885ca7a2046c24ce5b94917ad6c3693e7 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -92,6 +92,19 @@ class TPUContext(object): """ return self._internal_ctx.num_replicas + @property + def num_hosts(self): + """The number of hosts for the TPU system.""" + return self._internal_ctx.num_hosts + + @property + def num_of_replicas_per_host(self): + """The number of replicas for each host.""" + if self._internal_ctx.model_parallelism_enabled: + raise ValueError( + 'num_of_replicas_per_host is not supported for model_parallelism') + return self._internal_ctx.num_of_replicas_per_host + def device_for_replica(self, replica_id): """Returns the tuple of (CPU device and device ordinal) for replica. @@ -384,9 +397,7 @@ class _InternalTPUContext(object): # On TPU if self.is_input_sharded_per_core() or ( self.is_input_per_host_with_iterators()): - # We prohibit per core input sharding for the model parallelism case, - # therefore it is safe to use num_cores here. - return global_batch_size // self.num_cores + return global_batch_size // self.num_replicas else: return global_batch_size // self.num_hosts @@ -484,25 +495,27 @@ class _InternalTPUContext(object): return _placement_function - @property - def tpu_ordinal_function(self): + def tpu_ordinal_function(self, host_id): """Returns the TPU ordinal fn.""" - def _tpu_ordinal_function(index): + def _tpu_ordinal_function(shard_index_in_host): """Return the TPU ordinal associated with a shard. Required because the enqueue ops are placed on CPU. Args: - index: the shard index + shard_index_in_host: the shard index Returns: The ordinal of the TPU device the shard's infeed should be placed on. """ if self.model_parallelism_enabled: - return self.device_assignment.tpu_ordinal(replica=index) + # We put both enqueue/dequeue ops at tpu.core(0) in each replica. + replica = self.device_assignment.lookup_replicas( + host_id, (0, 0, 0))[shard_index_in_host] + return self.device_assignment.tpu_ordinal(replica=replica) else: - return index % self.num_of_cores_per_host + return shard_index_in_host % self.num_of_cores_per_host return _tpu_ordinal_function diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index e94bd78833f6cbe9adb1b6ca3f29a88bd8a53f64..14e025973ea99b387831637f905bd9f7c5f7e569 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -81,12 +81,17 @@ _TPU_ESTIMATOR = 'tpu_estimator' _ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop' _BATCH_SIZE_KEY = 'batch_size' _CTX_KEY = 'context' +_USE_TPU_KEY = 'use_tpu' _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' _ONE_GIGABYTE = 1024 * 1024 * 1024 _TPU_ENQUEUE_OPS = '_tpu_enqueue_ops' _TPU_TRAIN_OP = '_tpu_train_op' _REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference' +# Ideally _USE_TPU_KEY should be reserved as well. However there are already +# models that make use of this key, thus it can not be reserved now to prevent +# breakage. In the long run, we would like to mitigate this by migrating models +# off of using _USE_TPU_KEY. _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY] @@ -664,6 +669,7 @@ def generate_per_core_enqueue_ops_fn_for_host( ctx, input_fn, inputs_structure_recorder, host_device, host_id): """Generates infeed enqueue ops for per-core input_fn on a single host.""" captured_infeed_queue = _CapturedObject() + tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) def enqueue_ops_fn(): """A fn returns enqueue_ops.""" @@ -699,7 +705,7 @@ def generate_per_core_enqueue_ops_fn_for_host( per_host_sharded_inputs) per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs, tpu_ordinal_function=ctx.tpu_ordinal_function) + per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl) return per_host_enqueue_ops return enqueue_ops_fn, captured_infeed_queue @@ -734,19 +740,7 @@ def generate_per_host_enqueue_ops_fn_for_host( if is_dataset: hooks.append(inputs.dataset_initializer_hook()) - # TODO(ylc): Refactoring the code to merge the tpu ordinal logic here and the - # _InternalTPUContext.tpu_ordinal_function. We should either introduce another - # abstraction or a different helper method. - def _tpu_ordinal_function_impl(shard_index_in_host): - # We put both enqueue/dequeue op at tpu.core(0) in each replica. - replica = ctx.device_assignment.lookup_replicas( - host_id, (0, 0, 0))[shard_index_in_host] - return ctx.device_assignment.tpu_ordinal(replica=replica) - - if ctx.model_parallelism_enabled: - tpu_ordinal_function = _tpu_ordinal_function_impl - else: - tpu_ordinal_function = None + tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) def enqueue_ops_fn(): """A Fn returning the TPU infeed enqueue ops. @@ -782,7 +776,7 @@ def generate_per_host_enqueue_ops_fn_for_host( infeed_queue.split_inputs_and_generate_enqueue_ops( unsharded_tensor_list, placement_function=lambda x: device, - tpu_ordinal_function=tpu_ordinal_function)) + tpu_ordinal_function=tpu_ordinal_function_impl)) if signals is None: return per_host_enqueue_ops else: @@ -816,6 +810,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( raise TypeError('Most PREDICT not yet supported in PER_HOST_V2 mode.') hooks.append(inputs.dataset_initializer_hook()) + tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) def enqueue_ops_fn(): """Generates the per_host enqueue ops.""" @@ -846,7 +841,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( per_host_sharded_inputs) per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs, tpu_ordinal_function=ctx.tpu_ordinal_function) + per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl) return per_host_enqueue_ops return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset @@ -1146,7 +1141,7 @@ class _InputPipeline(object): err_msg = ('Input pipeline contains one or more QueueRunners. ' 'It could be slow and not scalable. Please consider ' 'converting your input pipeline to use `tf.data` instead (see ' - 'https://www.tensorflow.org/programmers_guide/datasets for ' + 'https://www.tensorflow.org/guide/datasets for ' 'instructions.') if _WRAP_INPUT_FN_INTO_WHILE_LOOP: raise RuntimeError(err_msg) @@ -1424,8 +1419,11 @@ class _ModelFnWrapper(object): if batch_size_for_model_fn is not None: _add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn) + running_on_cpu = self._ctx.is_running_on_cpu(is_export_mode) + _add_item_to_params(params, _USE_TPU_KEY, not running_on_cpu) + estimator_spec = self._model_fn(features=features, **kwargs) - if (self._ctx.is_running_on_cpu(is_export_mode) and + if (running_on_cpu and isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)): # pylint: disable=protected-access # The estimator_spec will be passed to `Estimator` directly, which expects # type `EstimatorSpec`. @@ -2043,24 +2041,29 @@ class TPUEstimator(estimator_lib.Estimator): strip_default_attrs, save_variables=True, mode=model_fn_lib.ModeKeys.PREDICT, - export_tags=None): + export_tags=None, + check_variables=True): if mode != model_fn_lib.ModeKeys.PREDICT: raise NotImplementedError( 'TPUEstimator only handles mode PREDICT for export_savedmodel(); ' 'got {}.'.format(mode)) - super(TPUEstimator, self)._add_meta_graph_for_mode(builder, - input_receiver_fn_map, - checkpoint_path, - strip_default_attrs, - save_variables, - mode=mode) + (super(TPUEstimator, self). + _add_meta_graph_for_mode(builder, + input_receiver_fn_map, + checkpoint_path, + strip_default_attrs, + save_variables, + mode=mode, + export_tags=export_tags, + check_variables=check_variables)) if self._export_to_tpu: input_receiver_fn_map = {_REWRITE_FOR_INFERENCE_MODE: input_receiver_fn_map[mode]} export_tags = [tag_constants.SERVING, tag_constants.TPU] mode = _REWRITE_FOR_INFERENCE_MODE + # See b/110052256 for why `check_variables` is `False`. (super(TPUEstimator, self). _add_meta_graph_for_mode(builder, input_receiver_fn_map, @@ -2068,7 +2071,8 @@ class TPUEstimator(estimator_lib.Estimator): strip_default_attrs, save_variables=False, mode=mode, - export_tags=export_tags)) + export_tags=export_tags, + check_variables=False)) def _call_model_fn(self, features, labels, mode, config): if mode == _REWRITE_FOR_INFERENCE_MODE: @@ -3115,7 +3119,7 @@ class _SignalsHelper(object): def __init__(self, signals): self._signal_keys = [] - for key in sorted(signals.iterkeys()): + for key in sorted(iter(signals.keys())): self._signal_keys.append(key) @property @@ -3127,7 +3131,7 @@ class _SignalsHelper(object): @staticmethod def as_tensor_list(signals): - return [signals[key] for key in sorted(signals.iterkeys())] + return [signals[key] for key in sorted(iter(signals.keys()))] def _verify_cross_hosts_transfer_size(tensor_dict, message): @@ -3153,7 +3157,7 @@ def _add_item_to_params(params, key, value): if isinstance(params, hparam.HParams): # For HParams, we need to use special API. if key in params: - params.key = value + params.set_hparam(key, value) else: params.add_hparam(key, value) else: diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index 5de55b5f7f2a41ac6edd27e5a102e565f33df12c..76927e62e82d02de172a0851819716dc63180371 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -295,7 +295,7 @@ py_test( tags = ["notsan"], deps = [ ":training_py", - "//tensorflow/contrib/data/python/kernel_tests:dataset_serialization_test", + "//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:gradients", diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py index 0338f409a203c232e63e99534a8f6d6a43fa661e..df0a186f4f6963d7e874bb4ab74a8db7e10a52ee 100644 --- a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py +++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py @@ -19,7 +19,7 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.contrib.training.python.training import tensor_queue_dataset as tqd from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index b6b48a077cdafe12aeb1e4e0988493692c82eace..0e6bc03c0b4b5a31d52bf5ce1f5c472a9e47d446 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -89,6 +89,7 @@ load( "tf_generate_proto_text_sources", "tf_genrule_cmd_append_to_srcs", "tf_opts_nortti_if_android", + "tf_features_nomodules_if_android", ) load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl") load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") @@ -792,6 +793,7 @@ tf_cuda_library( "framework/graph_def_util.h", "framework/graph_to_functiondef.h", "framework/kernel_def_builder.h", + "framework/kernel_def_util.h", "framework/log_memory.h", "framework/lookup_interface.h", "framework/memory_types.h", @@ -901,6 +903,15 @@ cc_library( hdrs = ["util/ptr_util.h"], ) +cc_library( + name = "status_util", + hdrs = ["util/status_util.h"], + deps = [ + ":graph", + ":lib", + ], +) + cc_library( name = "reader_base", srcs = ["framework/reader_base.cc"], @@ -998,6 +1009,7 @@ tf_gen_op_libs( "nn_ops", "no_op", "parsing_ops", + "random_grad", "random_ops", "remote_fused_graph_ops", "resource_variable_ops", @@ -1196,6 +1208,7 @@ tf_cuda_library( hdrs = [ "common_runtime/device.h", "common_runtime/device_factory.h", + "common_runtime/function.h", "common_runtime/optimization_registry.h", "common_runtime/shape_refiner.h", "graph/algorithm.h", @@ -1910,6 +1923,7 @@ tf_proto_library_cc( srcs = ["protobuf/master_service.proto"], has_services = 1, cc_api_version = 2, + cc_grpc_version = 1, cc_stubby_versions = ["2"], protodeps = [":master_proto"], visibility = [ @@ -2339,6 +2353,7 @@ FRAMEWORK_INTERNAL_PRIVATE_HEADERS = [ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [ "framework/op_segment.h", "framework/rendezvous.h", # only needed for tests + "framework/resource_var.h", "framework/tensor_reference.h", "framework/tracking_allocator.h", # only needed for tests "framework/unique_tensor_references.h", @@ -3369,10 +3384,12 @@ tf_cc_tests( "framework/bfloat16_test.cc", "framework/cancellation_test.cc", "framework/common_shape_fns_test.cc", + "framework/device_base_test.cc", "framework/function_test.cc", "framework/graph_def_util_test.cc", "framework/graph_to_functiondef_test.cc", "framework/kernel_def_builder_test.cc", + "framework/kernel_def_util_test.cc", "framework/memory_types_test.cc", "framework/node_def_builder_test.cc", "framework/node_def_util_test.cc", @@ -3397,6 +3414,7 @@ tf_cc_tests( "framework/variant_op_registry_test.cc", "framework/variant_test.cc", "graph/algorithm_test.cc", + "graph/control_flow_test.cc", "graph/edgeset_test.cc", "graph/graph_def_builder_test.cc", "graph/graph_partition_test.cc", @@ -3421,6 +3439,7 @@ tf_cc_tests( "util/semver_test.cc", "util/sparse/sparse_tensor_test.cc", "util/stat_summarizer_test.cc", + "util/status_util_test.cc", "util/tensor_format_test.cc", "util/tensor_slice_reader_test.cc", "util/tensor_slice_set_test.cc", @@ -3445,6 +3464,7 @@ tf_cc_tests( ":ops", ":protos_all_cc", ":protos_test_cc", + ":status_util", ":test", ":test_main", ":testlib", @@ -3903,13 +3923,13 @@ tf_cc_test( ], ) -tf_cc_test( +tf_cuda_cc_test( name = "common_runtime_direct_session_test", size = "small", srcs = ["common_runtime/direct_session_test.cc"], + args = [] + if_cuda(["--heap_check=local"]), # The GPU tracer leaks memory linkstatic = tf_kernel_tests_linkstatic(), deps = [ - ":core", ":core_cpu", ":core_cpu_internal", ":direct_session_internal", @@ -3922,6 +3942,7 @@ tf_cc_test( ":test", ":test_main", ":testlib", + "//third_party/eigen3", "//tensorflow/cc:cc_ops", "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:cwise_op", @@ -3935,8 +3956,7 @@ tf_cc_test( "//tensorflow/core/kernels:queue_ops", "//tensorflow/core/kernels:session_ops", "//tensorflow/core/kernels:variable_ops", - "//third_party/eigen3", - ], + ] + if_cuda([":cuda"]), ) # This is identical to :common_runtime_direct_session_test with the addition of diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc index 477a0b670e49f8aa4ee8c250d4957886eb865ed5..6149e5fca804fba54b6c77183e5c271c9b9d9a81 100644 --- a/tensorflow/core/api_def/api_test.cc +++ b/tensorflow/core/api_def/api_test.cc @@ -171,7 +171,7 @@ TEST_F(BaseApiTest, AllOpsAreInApiDef) { if (excluded_ops->find(op.name()) != excluded_ops->end()) { continue; } - ASSERT_TRUE(api_defs_map_.find(op.name()) != api_defs_map_.end()) + EXPECT_TRUE(api_defs_map_.find(op.name()) != api_defs_map_.end()) << op.name() << " op does not have api_def_*.pbtxt file. " << "Please add api_def_" << op.name() << ".pbtxt file " << "under tensorflow/core/api_def/base_api/ directory."; diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesExampleDebugOutputs.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesExampleDebugOutputs.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..206fa3cc989c61b359d8c539fb02e1d95bf994a7 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesExampleDebugOutputs.pbtxt @@ -0,0 +1,36 @@ +op { + graph_op_name: "BoostedTreesExampleDebugOutputs" + visibility: HIDDEN + in_arg { + name: "bucketized_features" + description: < + - classname: tfo-landing-row-item-code-block + code_block: | +
+        import tensorflow as tf
+        mnist = tf.keras.datasets.mnist
+
+        (x_train, y_train),(x_test, y_test) = mnist.load_data()
+        x_train, x_test = x_train / 255.0, x_test / 255.0
+
+        model = tf.keras.models.Sequential([
+          tf.keras.layers.Flatten(),
+          tf.keras.layers.Dense(512, activation=tf.nn.relu),
+          tf.keras.layers.Dropout(0.2),
+          tf.keras.layers.Dense(10, activation=tf.nn.softmax)
+        ])
+        model.compile(optimizer='adam',
+                      loss='sparse_categorical_crossentropy',
+                      metrics=['accuracy'])
+
+        model.fit(x_train, y_train, epochs=5)
+        model.evaluate(x_test, y_test)
+        
+ {% dynamic if request.tld != 'cn' %} + Run in a Notebook + {% dynamic endif %} + + - items: + - custom_html: > +
+

Research and experimentation

+
+

+ Eager execution provides an imperative, define-by-run interface for advanced operations. Write custom layers, forward passes, and training loops with auto‑differentiation. Start with + these notebooks, then read the eager execution guide. +

+
    +
  1. + {% dynamic if request.tld == 'cn' %} + Eager execution basics + {% dynamic else %} + Eager execution basics + {% dynamic endif %} +
  2. +
  3. + {% dynamic if request.tld == 'cn' %} + Automatic differentiation and gradient tapes + {% dynamic else %} + Automatic differentiation and gradient tapes + {% dynamic endif %} +
  4. +
  5. + {% dynamic if request.tld == 'cn' %} + Variables, models, and training + {% dynamic else %} + Variables, models, and training + {% dynamic endif %} +
  6. +
  7. + {% dynamic if request.tld == 'cn' %} + Custom layers + {% dynamic else %} + Custom layers + {% dynamic endif %} +
  8. +
  9. Custom training walkthrough
  10. +
  11. + {% dynamic if request.tld == 'cn' %} + Example: Neural machine translation w/ attention + {% dynamic else %} + Example: Neural machine translation w/ attention + {% dynamic endif %} +
  12. +
+
+ +
+ - custom_html: > +
+

ML at production scale

+
+

+ Estimators can train large models on multiple machines in a + production environment. Try the examples below and read the + Estimators guide. +

+
    +
  1. How to build a simple text classifier with TF-Hub
  2. +
  3. Classifying Higgs boson processes
  4. +
  5. Wide and deep learning using estimators
  6. +
+
+ +
+ + - description: > +

Google Colab: An easy way to learn and use TensorFlow

+

+ Colaboratory + is a Google research project created to help disseminate machine learning + education and research. It's a Jupyter notebook environment that requires + no setup to use and runs entirely in the cloud. + Read the blog post. +

+ + - description: > +

Build your first ML app

+

Create and deploy TensorFlow models on web and mobile.

+ background: grey + items: + - custom_html: > +
+ +

Web developers

+
+
+ TensorFlow.js is a WebGL accelerated, JavaScript library to train and + deploy ML models in the browser and for Node.js. +
+
+ - custom_html: > +
+ +

Mobile developers

+
+
+ TensorFlow Lite is lightweight solution for mobile and embedded devices. +
+
+ + - description: > +

Videos and updates

+

+ Subscribe to the TensorFlow + YouTube channel + and blog for + the latest videos and updates. +

+ items: + - description: > +

Get started with TensorFlow's High-Level APIs

+ youtube_id: tjsHSIG8I08 + buttons: + - label: Watch the video + path: https://www.youtube.com/watch?v=tjsHSIG8I08 + - description: > +

Eager execution

+ youtube_id: T8AW0fKP0Hs + background: grey + buttons: + - label: Watch the video + path: https://www.youtube.com/watch?v=T8AW0fKP0Hs + - description: > +

tf.data: Fast, flexible, and easy-to-use input pipelines

+ youtube_id: uIcqeP7MFH0 + buttons: + - label: Watch the video + path: https://www.youtube.com/watch?v=uIcqeP7MFH0 diff --git a/tensorflow/docs_src/get_started/basic_classification.md b/tensorflow/docs_src/get_started/basic_classification.md new file mode 100644 index 0000000000000000000000000000000000000000..91bbd85b2442522ef34eba236bf5bab2fc8654a7 --- /dev/null +++ b/tensorflow/docs_src/get_started/basic_classification.md @@ -0,0 +1,3 @@ +# Basic Classification + +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/get_started/basic_classification.ipynb) diff --git a/tensorflow/docs_src/get_started/basic_regression.md b/tensorflow/docs_src/get_started/basic_regression.md new file mode 100644 index 0000000000000000000000000000000000000000..a535f22f5a41e7cb34cb8424b60d10d4ad43940e --- /dev/null +++ b/tensorflow/docs_src/get_started/basic_regression.md @@ -0,0 +1,3 @@ +# Basic Regression + +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/get_started/basic_regression.ipynb) diff --git a/tensorflow/docs_src/get_started/basic_text_classification.md b/tensorflow/docs_src/get_started/basic_text_classification.md new file mode 100644 index 0000000000000000000000000000000000000000..7c5d4f78968f94e4d5685a2dffe75ab649431e38 --- /dev/null +++ b/tensorflow/docs_src/get_started/basic_text_classification.md @@ -0,0 +1,3 @@ +# Basic Text Classification + +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/get_started/basic_text_classification.ipynb) diff --git a/tensorflow/docs_src/get_started/eager.md b/tensorflow/docs_src/get_started/eager.md index bbb25e20c62f6a2eec78668250a0e748494797c5..ddf239485a5546e0566d742f19c5d5b7025b157b 100644 --- a/tensorflow/docs_src/get_started/eager.md +++ b/tensorflow/docs_src/get_started/eager.md @@ -1,3 +1,3 @@ -# Get Started with Eager Execution +# Custom Training Walkthrough [Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/r1.9.0/samples/core/get_started/eager.ipynb) diff --git a/tensorflow/docs_src/get_started/index.md b/tensorflow/docs_src/get_started/index.md index 232d2f154703dc10320f9ee074c67d6e1a8ee850..bd2a80d9efee7f76111c6769b97d640af216c833 100644 --- a/tensorflow/docs_src/get_started/index.md +++ b/tensorflow/docs_src/get_started/index.md @@ -23,7 +23,7 @@ For more advanced users: * The @{$low_level_intro$Low Level Introduction} demonstrates how to use TensorFlow outside of the Estimator framework, for debugging and experimentation. - * The @{$programmers_guide$Programmer's Guide} details major + * The @{$guide$Programmer's Guide} details major TensorFlow components. * The @{$tutorials$Tutorials} provide walkthroughs of a variety of TensorFlow models. diff --git a/tensorflow/docs_src/get_started/leftnav_files b/tensorflow/docs_src/get_started/leftnav_files index e6cc8d565810683947e9cf9692e7cccb43916e66..99d2b2c3e1fa257fd03636eb44b2b65201ae311d 100644 --- a/tensorflow/docs_src/get_started/leftnav_files +++ b/tensorflow/docs_src/get_started/leftnav_files @@ -1,4 +1,10 @@ -index.md +### Learn and use ML +basic_classification.md: Basic classification +basic_text_classification.md: Text classification +basic_regression.md: Regression +overfit_and_underfit.md +save_and_restore_models.md +next_steps.md +### Research and experimentation eager.md -datasets_quickstart.md diff --git a/tensorflow/docs_src/get_started/next_steps.md b/tensorflow/docs_src/get_started/next_steps.md new file mode 100644 index 0000000000000000000000000000000000000000..01c9f7204a7ddae16bcbd9eb5702516a39f8ce4c --- /dev/null +++ b/tensorflow/docs_src/get_started/next_steps.md @@ -0,0 +1,36 @@ +# Next steps + +## Learn more about TensorFlow + +* The [TensorFlow Guide](/guide) includes usage guides for the + high-level APIs, as well as advanced TensorFlow operations. +* [Premade Estimators](/guide/premade_estimators) are designed to + get results out of the box. Use TensorFlow without building your own models. +* [TensorFlow.js](https://js.tensorflow.org/) allows web developers to train and + deploy ML models in the browser and using Node.js. +* [TFLite](/mobile/tflite) allows mobile developers to do inference efficiently + on mobile devices. +* [TensorFlow Serving](/serving) is an open-source project that can put + TensorFlow models in production quickly. +* The [ecosystem](/ecosystem) contains more projects, including + [Magenta](https://magenta.tensorflow.org/), [TFX](/tfx), + [Swift for TensorFlow](https://github.com/tensorflow/swift), and more. + +## Learn more about machine learning + +Recommended resources include: + +* [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/), + a course from Google that introduces machine learning concepts. +* [CS 20: Tensorflow for Deep Learning Research](http://web.stanford.edu/class/cs20si/), + notes from an intro course from Stanford. +* [CS231n: Convolutional Neural Networks for Visual Recognition](http://cs231n.stanford.edu/), + a course that teaches how convolutional networks work. +* [Machine Learning Recipes](https://www.youtube.com/watch?v=cKxRvEZd3Mw&list=PLOU2XLYxmsIIuiBfYad6rFYQU_jL2ryal), + a video series that introduces basic machine learning concepts with few prerequisites. +* [Deep Learning with Python](https://www.manning.com/books/deep-learning-with-python), + a book by Francois Chollet about the Keras API, as well as an excellent hands on intro to Deep Learning. +* [Hands-on Machine Learning with Scikit-Learn and TensorFlow](https://github.com/ageron/handson-ml), + a book by AurĆ©lien Geron's that is a clear getting-started guide to data science and deep learning. +* [Deep Learning](https://www.deeplearningbook.org/), a book by Ian Goodfellow et al. + that provides a technical dive into learning machine learning. diff --git a/tensorflow/docs_src/get_started/overfit_and_underfit.md b/tensorflow/docs_src/get_started/overfit_and_underfit.md new file mode 100644 index 0000000000000000000000000000000000000000..e5b5ae7b5a70f476c25cc7bb76572bf6433c289f --- /dev/null +++ b/tensorflow/docs_src/get_started/overfit_and_underfit.md @@ -0,0 +1,3 @@ +# Overfitting and Underfitting + +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/get_started/overfit_and_underfit.ipynb) diff --git a/tensorflow/docs_src/get_started/save_and_restore_models.md b/tensorflow/docs_src/get_started/save_and_restore_models.md new file mode 100644 index 0000000000000000000000000000000000000000..44b377294562cf5a0c8139e88d0c7226506b32ba --- /dev/null +++ b/tensorflow/docs_src/get_started/save_and_restore_models.md @@ -0,0 +1,3 @@ +# Save and restore Models + +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/get_started/save_and_restore_models.ipynb) diff --git a/tensorflow/docs_src/programmers_guide/checkpoints.md b/tensorflow/docs_src/guide/checkpoints.md similarity index 96% rename from tensorflow/docs_src/programmers_guide/checkpoints.md rename to tensorflow/docs_src/guide/checkpoints.md index 8dfd91e3c8368f4a649c5b5fa3947e97441ef390..dfb2626b8675ccc3db293498314fcc3e417bc1bd 100644 --- a/tensorflow/docs_src/programmers_guide/checkpoints.md +++ b/tensorflow/docs_src/guide/checkpoints.md @@ -8,9 +8,8 @@ Estimators. TensorFlow provides two model formats: * SavedModel, which is a format independent of the code that created the model. -This document focuses on checkpoints. For details on SavedModel, see the -@{$saved_model$Saving and Restoring} chapter of the -*TensorFlow Programmer's Guide*. +This document focuses on checkpoints. For details on `SavedModel`, see the +@{$saved_model$Saving and Restoring} guide. ## Sample code @@ -232,8 +231,7 @@ This separation will keep your checkpoints recoverable. Checkpoints provide an easy automatic mechanism for saving and restoring models created by Estimators. -See the @{$saved_model$Saving and Restoring} -chapter of the *TensorFlow Programmer's Guide* for details on: +See the @{$saved_model$Saving and Restoring} guide for details about: * Saving and restoring models using low-level TensorFlow APIs. * Exporting and importing models in the SavedModel format, which is a diff --git a/tensorflow/docs_src/programmers_guide/custom_estimators.md b/tensorflow/docs_src/guide/custom_estimators.md similarity index 98% rename from tensorflow/docs_src/programmers_guide/custom_estimators.md rename to tensorflow/docs_src/guide/custom_estimators.md index fb20b35c128b5bdafbb88ccb19df05f6a73c9977..a63e2bafb362c660d9203c609e46cdffb7955342 100644 --- a/tensorflow/docs_src/programmers_guide/custom_estimators.md +++ b/tensorflow/docs_src/guide/custom_estimators.md @@ -362,10 +362,10 @@ model's loss. This is the that will be optimized. We can calculate the loss by calling @{tf.losses.sparse_softmax_cross_entropy}. -The value returned by this function will be lowest, approximately 0, -probability of the correct class (at index `label`) is near 1.0. The loss value -returned is progressively larger as the probability of the correct class -decreases. +The value returned by this function will be approximately 0 at lowest, +when the probability of the correct class (at index `label`) is near 1.0. +The loss value returned is progressively larger as the probability of the +correct class decreases. This function returns the average over the whole batch. diff --git a/tensorflow/docs_src/programmers_guide/datasets.md b/tensorflow/docs_src/guide/datasets.md similarity index 100% rename from tensorflow/docs_src/programmers_guide/datasets.md rename to tensorflow/docs_src/guide/datasets.md diff --git a/tensorflow/docs_src/get_started/datasets_quickstart.md b/tensorflow/docs_src/guide/datasets_for_estimators.md similarity index 97% rename from tensorflow/docs_src/get_started/datasets_quickstart.md rename to tensorflow/docs_src/guide/datasets_for_estimators.md index 020e40dd3b8f046f0144e3806468f58833f7b607..b04af78cd820f1b3506f62112f25dd8fdb73e76c 100644 --- a/tensorflow/docs_src/get_started/datasets_quickstart.md +++ b/tensorflow/docs_src/guide/datasets_for_estimators.md @@ -1,4 +1,4 @@ -# Datasets Quick Start +# Datasets for Estimators The @{tf.data} module contains a collection of classes that allows you to easily load data, manipulate it, and pipe it into your model. This document @@ -91,8 +91,8 @@ print(mnist_ds) ``` This will print the following line, showing the -@{$programmers_guide/tensors#shapes$shapes} and -@{$programmers_guide/tensors#data_types$types} of the items in +@{$guide/tensors#shapes$shapes} and +@{$guide/tensors#data_types$types} of the items in the dataset. Note that a `Dataset` does not know how many items it contains. ``` None @@ -128,7 +128,7 @@ print(dataset) Here we see that when a `Dataset` contains structured elements, the `shapes` and `types` of the `Dataset` take on the same structure. This dataset contains -dictionaries of @{$programmers_guide/tensors#rank$scalars}, all of type +dictionaries of @{$guide/tensors#rank$scalars}, all of type `tf.float64`. The first line of the iris `train_input_fn` uses the same functionality, but @@ -382,6 +382,6 @@ Estimator. Consider the following documents next: * The @{$low_level_intro#datasets$Low Level Introduction}, which demonstrates how to experiment directly with `tf.data.Datasets` using TensorFlow's low level APIs. -* @{$programmers_guide/datasets} which goes into great detail about additional +* @{$guide/datasets} which goes into great detail about additional functionality of `Datasets`. diff --git a/tensorflow/docs_src/programmers_guide/debugger.md b/tensorflow/docs_src/guide/debugger.md similarity index 99% rename from tensorflow/docs_src/programmers_guide/debugger.md rename to tensorflow/docs_src/guide/debugger.md index 49258c7b4a9051c050d762f62c8e4439cea4f198..dc4db58857c211f95bd7d5f2b3232e63f9877288 100644 --- a/tensorflow/docs_src/programmers_guide/debugger.md +++ b/tensorflow/docs_src/guide/debugger.md @@ -210,6 +210,7 @@ Try the following commands at the `tfdbg>` prompt (referencing the code at | **`config`** | | **Set or show persistent TFDBG UI configuration.** | | | | `set` | Set the value of a config item: {`graph_recursion_depth`, `mouse_mode`}. | `config set graph_recursion_depth 3` | | | `show` | Show current persistent UI configuration. | `config show` | +| **`version`** | | **Print the version of TensorFlow and its key dependencies.** | `version` | | **`help`** | | **Print general help information** | `help` | | | `help ` | Print help for given command. | `help lt` | diff --git a/tensorflow/docs_src/programmers_guide/eager.md b/tensorflow/docs_src/guide/eager.md similarity index 100% rename from tensorflow/docs_src/programmers_guide/eager.md rename to tensorflow/docs_src/guide/eager.md diff --git a/tensorflow/docs_src/programmers_guide/embedding.md b/tensorflow/docs_src/guide/embedding.md similarity index 100% rename from tensorflow/docs_src/programmers_guide/embedding.md rename to tensorflow/docs_src/guide/embedding.md diff --git a/tensorflow/docs_src/programmers_guide/estimators.md b/tensorflow/docs_src/guide/estimators.md similarity index 99% rename from tensorflow/docs_src/programmers_guide/estimators.md rename to tensorflow/docs_src/guide/estimators.md index b13b47184d2b32fffb2390b0318fba8612d7826a..78b30c3040f646e4ae1bf97246666e8585e18057 100644 --- a/tensorflow/docs_src/programmers_guide/estimators.md +++ b/tensorflow/docs_src/guide/estimators.md @@ -81,7 +81,7 @@ of the following four steps: ... # manipulate dataset, extracting the feature dict and the label return feature_dict, label - (See @{$programmers_guide/datasets} for full details.) + (See @{$guide/datasets} for full details.) 2. **Define the feature columns.** Each @{tf.feature_column} identifies a feature name, its type, and any input pre-processing. diff --git a/tensorflow/docs_src/programmers_guide/faq.md b/tensorflow/docs_src/guide/faq.md similarity index 100% rename from tensorflow/docs_src/programmers_guide/faq.md rename to tensorflow/docs_src/guide/faq.md diff --git a/tensorflow/docs_src/programmers_guide/feature_columns.md b/tensorflow/docs_src/guide/feature_columns.md similarity index 99% rename from tensorflow/docs_src/programmers_guide/feature_columns.md rename to tensorflow/docs_src/guide/feature_columns.md index 90f5c53a17f23200f238f6b0d171e1e225330e27..1013ec910c1ebf9b781a9e84b6f5f33bcaa73690 100644 --- a/tensorflow/docs_src/programmers_guide/feature_columns.md +++ b/tensorflow/docs_src/guide/feature_columns.md @@ -534,7 +534,7 @@ embedding_column = tf.feature_column.embedding_column( dimension=embedding_dimensions) ``` -@{$programmers_guide/embedding$Embeddings} is a significant topic within machine +@{$guide/embedding$Embeddings} is a significant topic within machine learning. This information was just to get you started using them as feature columns. diff --git a/tensorflow/docs_src/programmers_guide/graph_viz.md b/tensorflow/docs_src/guide/graph_viz.md similarity index 100% rename from tensorflow/docs_src/programmers_guide/graph_viz.md rename to tensorflow/docs_src/guide/graph_viz.md diff --git a/tensorflow/docs_src/programmers_guide/graphs.md b/tensorflow/docs_src/guide/graphs.md similarity index 99% rename from tensorflow/docs_src/programmers_guide/graphs.md rename to tensorflow/docs_src/guide/graphs.md index f0dd8def17fd6dfed241167a5ebb5be678152c16..e6246ef148d8a5ddea65be53f1bb32193d4845ad 100644 --- a/tensorflow/docs_src/programmers_guide/graphs.md +++ b/tensorflow/docs_src/guide/graphs.md @@ -93,7 +93,7 @@ to all API functions in the same context. For example: stored value. The @{tf.Variable} object also has methods such as @{tf.Variable.assign$`assign`} and @{tf.Variable.assign_add$`assign_add`} that create @{tf.Operation} objects that, when executed, update the stored value. - (See @{$programmers_guide/variables} for more information about variables.) + (See @{$guide/variables} for more information about variables.) * Calling @{tf.train.Optimizer.minimize} will add operations and tensors to the default graph that calculates gradients, and return a @{tf.Operation} that, diff --git a/tensorflow/docs_src/programmers_guide/index.md b/tensorflow/docs_src/guide/index.md similarity index 71% rename from tensorflow/docs_src/programmers_guide/index.md rename to tensorflow/docs_src/guide/index.md index 0c2d4afb115c592c1925dde98b3a1a8c2a7ccad1..eefdb9ceae7a0d9ec45f476b0b3e82175830acc2 100644 --- a/tensorflow/docs_src/programmers_guide/index.md +++ b/tensorflow/docs_src/guide/index.md @@ -1,17 +1,17 @@ -# Programmer's Guide +# TensorFlow Guide The documents in this unit dive into the details of how TensorFlow works. The units are as follows: ## High Level APIs - * @{$programmers_guide/keras}, TensorFlow's high-level API for building and + * @{$guide/keras}, TensorFlow's high-level API for building and training deep learning models. - * @{$programmers_guide/eager}, an API for writing TensorFlow code + * @{$guide/eager}, an API for writing TensorFlow code imperatively, like you would use Numpy. - * @{$programmers_guide/estimators}, a high-level API that provides + * @{$guide/estimators}, a high-level API that provides fully-packaged models ready for large-scale training and production. - * @{$programmers_guide/datasets}, easy input pipelines to bring your data into + * @{$guide/datasets}, easy input pipelines to bring your data into your TensorFlow program. ## Estimators @@ -22,6 +22,7 @@ works. The units are as follows: design yourself. * @{$feature_columns}, which shows how an Estimator can handle a variety of input data types without changes to the model. +* @{$datasets_for_estimators} describes using tf.data with estimators. * @{$checkpoints}, which explains how to save training progress and resume where you left off. @@ -33,13 +34,13 @@ works. The units are as follows: ## Low Level APIs - * @{$programmers_guide/low_level_intro}, which introduces the + * @{$guide/low_level_intro}, which introduces the basics of how you can use TensorFlow outside of the high Level APIs. - * @{$programmers_guide/tensors}, which explains how to create, + * @{$guide/tensors}, which explains how to create, manipulate, and access Tensors--the fundamental object in TensorFlow. - * @{$programmers_guide/variables}, which details how + * @{$guide/variables}, which details how to represent shared, persistent state in your program. - * @{$programmers_guide/graphs}, which explains: + * @{$guide/graphs}, which explains: * dataflow graphs, which are TensorFlow's representation of computations as dependencies between operations. * sessions, which are TensorFlow's mechanism for running dataflow graphs @@ -49,19 +50,19 @@ works. The units are as follows: such as Estimators or Keras, the high-level API creates and manages graphs and sessions for you, but understanding graphs and sessions can still be helpful. - * @{$programmers_guide/saved_model}, which + * @{$guide/saved_model}, which explains how to save and restore variables and models. ## ML Concepts - * @{$programmers_guide/embedding}, which introduces the concept + * @{$guide/embedding}, which introduces the concept of embeddings, provides a simple example of training an embedding in TensorFlow, and explains how to view embeddings with the TensorBoard Embedding Projector. ## Debugging - * @{$programmers_guide/debugger}, which + * @{$guide/debugger}, which explains how to use the TensorFlow debugger (tfdbg). ## TensorBoard @@ -69,17 +70,17 @@ works. The units are as follows: TensorBoard is a utility to visualize different aspects of machine learning. The following guides explain how to use TensorBoard: - * @{$programmers_guide/summaries_and_tensorboard}, + * @{$guide/summaries_and_tensorboard}, which introduces TensorBoard. - * @{$programmers_guide/graph_viz}, which + * @{$guide/graph_viz}, which explains how to visualize the computational graph. - * @{$programmers_guide/tensorboard_histograms} which demonstrates the how to + * @{$guide/tensorboard_histograms} which demonstrates the how to use TensorBoard's histogram dashboard. ## Misc - * @{$programmers_guide/version_compat}, + * @{$guide/version_compat}, which explains backward compatibility guarantees and non-guarantees. - * @{$programmers_guide/faq}, which contains frequently asked + * @{$guide/faq}, which contains frequently asked questions about TensorFlow. diff --git a/tensorflow/docs_src/programmers_guide/keras.md b/tensorflow/docs_src/guide/keras.md similarity index 95% rename from tensorflow/docs_src/programmers_guide/keras.md rename to tensorflow/docs_src/guide/keras.md index c6aca7ebf4edd085e2b47492fd3a86578620492d..1d846df1044cd7100c083aa7d6b5be8f9cdd584e 100644 --- a/tensorflow/docs_src/programmers_guide/keras.md +++ b/tensorflow/docs_src/guide/keras.md @@ -19,7 +19,7 @@ fast prototyping, advanced research, and production, with three key advantages: [Keras API specification](https://keras.io){:.external}. This is a high-level API to build and train models that includes first-class support for TensorFlow-specific functionality, such as [eager execution](#eager_execution), -`tf.data` pipelines, and [Estimators](/programmers_guide/estimators). +`tf.data` pipelines, and [Estimators](./estimators.md). `tf.keras` makes TensorFlow easier to use without sacrificing flexibility and performance. @@ -35,8 +35,8 @@ from tensorflow import keras * The `tf.keras` version in the latest TensorFlow release might not be the same as the latest `keras` version from PyPI. Check `tf.keras.__version__`. * When [saving a model's weights](#weights_only), `tf.keras` defaults to the - [checkpoint format](/get_started/checkpoints). Pass `save_format='h5'` to use - HDF5. + [checkpoint format](./checkpoints.md). Pass `save_format='h5'` to + use HDF5. ## Build a simple model @@ -179,7 +179,7 @@ model.fit(data, labels, epochs=10, batch_size=32, ### Input tf.data datasets -Use the [Datasets API](/programmers_guide/datasets) to scale to large datasets +Use the [Datasets API](./datasets.md) to scale to large datasets or multi-device training. Pass a `tf.data.Dataset` instance to the `fit` method: @@ -221,7 +221,7 @@ To *evaluate* the inference-mode loss and metrics for the data provided: ```python model.evaluate(x, y, batch_size=32) -model.evaluate(dataset, steps=30 +model.evaluate(dataset, steps=30) ``` And to *predict* the output of the last layer in inference for the data provided, @@ -285,7 +285,7 @@ your own forward pass. Create layers in the `__init__` method and set them as attributes of the class instance. Define the forward pass in the `call` method. Model subclassing is particularly useful when -[eager execution](/programmers_guide/eager) is enabled since the forward pass +[eager execution](./eager.md) is enabled since the forward pass can be written imperatively. Key Point: Use the right API for the job. While model subclassing offers @@ -410,7 +410,7 @@ during training. You can write your own custom callback, or use the built-in * `tf.keras.callbacks.EarlyStopping`: Interrupt training when validation performance has stopped improving. * `tf.keras.callbacks.TensorBoard`: Monitor the model's behavior using - [TensorBoard](/programmers_guide/summaries_and_tensorboard). + [TensorBoard](./summaries_and_tensorboard.md). To use a `tf.keras.callbacks.Callback`, pass it to the model's `fit` method: @@ -442,8 +442,8 @@ model.load_weights('my_model') ``` By default, this saves the model's weights in the -[TensorFlow checkpoint](/get_started/checkpoints) file format. Weights can also -be saved to the Keras HDF5 format (the default for the multi-backend +[TensorFlow checkpoint](./checkpoints.md) file format. Weights can +also be saved to the Keras HDF5 format (the default for the multi-backend implementation of Keras): ```python @@ -509,7 +509,7 @@ model = keras.models.load_model('my_model.h5') ## Eager execution -[Eager execution](/programmers_guide/eager) is an imperative programming +[Eager execution](./eager.md) is an imperative programming environment that evaluates operations immediately. This is not required for Keras, but is supported by `tf.keras` and useful for inspecting your program and debugging. @@ -520,7 +520,7 @@ especially benefits *model subclassing* and building *custom layers*—the APIs that require you to write the forward pass as code (instead of the APIs that create models by assembling existing layers). -See the [eager execution guide](/programmers_guide/eager#build_a_model) for +See the [eager execution guide](./eager.md#build_a_model) for examples of using Keras models with custom training loops and `tf.GradientTape`. @@ -528,14 +528,14 @@ examples of using Keras models with custom training loops and `tf.GradientTape`. ### Estimators -The [Estimators](/programmers_guide/estimators) API is used for training models +The [Estimators](./estimators.md) API is used for training models for distributed environments. This targets industry use cases such as distributed training on large datasets that can export a model for production. A `tf.keras.Model` can be trained with the `tf.estimator` API by converting the model to an `tf.estimator.Estimator` object with `tf.keras.estimator.model_to_estimator`. See -[Creating Estimators from Keras models](/programmers_guide/estimators#creating_estimators_from_keras_models). +[Creating Estimators from Keras models](./estimators.md#creating_estimators_from_keras_models). ```python model = keras.Sequential([layers.Dense(10,activation='softmax'), @@ -548,8 +548,8 @@ model.compile(optimizer=tf.train.RMSPropOptimizer(0.001), estimator = keras.estimator.model_to_estimator(model) ``` -Note: Enable [eager execution](/programmers_guide/eager) for debugging -[Estimator input functions](/programmers_guide/premade_estimators#create_input_functions) +Note: Enable [eager execution](./eager.md) for debugging +[Estimator input functions](./premade_estimators.md#create_input_functions) and inspecting data. ### Multiple GPUs @@ -581,15 +581,6 @@ model.compile(loss='binary_crossentropy', optimizer=optimizer) model.summary() ``` -Convert the Keras model to a `tf.estimator.Estimator` instance: - -```python -keras_estimator = keras.estimator.model_to_estimator( - keras_model=model, - config=config, - model_dir='/tmp/model_dir') -``` - Define an *input pipeline*. The `input_fn` returns a `tf.data.Dataset` object used to distribute the data across multiple devices—with each device processing a slice of the input batch. @@ -615,6 +606,15 @@ strategy = tf.contrib.distribute.MirroredStrategy() config = tf.estimator.RunConfig(train_distribute=strategy) ``` +Convert the Keras model to a `tf.estimator.Estimator` instance: + +```python +keras_estimator = keras.estimator.model_to_estimator( + keras_model=model, + config=config, + model_dir='/tmp/model_dir') +``` + Finally, train the `Estimator` instance by providing the `input_fn` and `steps` arguments: diff --git a/tensorflow/docs_src/programmers_guide/leftnav_files b/tensorflow/docs_src/guide/leftnav_files similarity index 95% rename from tensorflow/docs_src/programmers_guide/leftnav_files rename to tensorflow/docs_src/guide/leftnav_files index 3bcf864e13db0cef40cec74ab872c807c2ec2fb0..357a2a1cb929e05be03fe19bd9dded8050149998 100644 --- a/tensorflow/docs_src/programmers_guide/leftnav_files +++ b/tensorflow/docs_src/guide/leftnav_files @@ -10,6 +10,7 @@ estimators.md: Introduction to Estimators premade_estimators.md custom_estimators.md feature_columns.md +datasets_for_estimators.md checkpoints.md ### Accelerators diff --git a/tensorflow/docs_src/programmers_guide/low_level_intro.md b/tensorflow/docs_src/guide/low_level_intro.md similarity index 99% rename from tensorflow/docs_src/programmers_guide/low_level_intro.md rename to tensorflow/docs_src/guide/low_level_intro.md index 478e2bb70bc7f58156398c9f9fef4e76ba581e1a..665a5568b49a4cf3ee47d60617116f73e0db364f 100644 --- a/tensorflow/docs_src/programmers_guide/low_level_intro.md +++ b/tensorflow/docs_src/guide/low_level_intro.md @@ -303,7 +303,7 @@ while True: break ``` -For more details on Datasets and Iterators see: @{$programmers_guide/datasets}. +For more details on Datasets and Iterators see: @{$guide/datasets}. ## Layers diff --git a/tensorflow/docs_src/programmers_guide/premade_estimators.md b/tensorflow/docs_src/guide/premade_estimators.md similarity index 98% rename from tensorflow/docs_src/programmers_guide/premade_estimators.md rename to tensorflow/docs_src/guide/premade_estimators.md index f6dd75eacab1c99215ab918a0854b0a33d0d9cca..3e910c1fe2ebfdffc25044f15b3558407d407ef1 100644 --- a/tensorflow/docs_src/programmers_guide/premade_estimators.md +++ b/tensorflow/docs_src/guide/premade_estimators.md @@ -78,10 +78,10 @@ provides a programming stack consisting of multiple API layers: We strongly recommend writing TensorFlow programs with the following APIs: -* @{$programmers_guide/estimators$Estimators}, which represent a complete model. +* @{$guide/estimators$Estimators}, which represent a complete model. The Estimator API provides methods to train the model, to judge the model's accuracy, and to generate predictions. -* @{$get_started/datasets_quickstart$Datasets}, which build a data input +* @{$guide/datasets_for_estimators}, which build a data input pipeline. The Dataset API has methods to load and manipulate data, and feed it into your model. The Dataset API meshes well with the Estimators API. @@ -173,7 +173,7 @@ example is an Iris Versicolor. An Estimator is TensorFlow's high-level representation of a complete model. It handles the details of initialization, logging, saving and restoring, and many other features so you can concentrate on your model. For more details see -@{$programmers_guide/estimators}. +@{$guide/estimators}. An Estimator is any class derived from @{tf.estimator.Estimator}. TensorFlow provides a collection of @@ -424,9 +424,7 @@ Now that you've gotten started writing TensorFlow programs, consider the following material: * @{$checkpoints$Checkpoints} to learn how to save and restore models. -* @{$get_started/datasets_quickstart$Datasets} to learn more about importing - data into your - model. +* @{$guide/datasets_for_estimators} to learn more about importing + data into your model. * @{$custom_estimators$Creating Custom Estimators} to learn how to write your own Estimator, customized for a particular problem. - diff --git a/tensorflow/docs_src/programmers_guide/saved_model.md b/tensorflow/docs_src/guide/saved_model.md similarity index 99% rename from tensorflow/docs_src/programmers_guide/saved_model.md rename to tensorflow/docs_src/guide/saved_model.md index c6ef87c54a3bc37dbfc0553232a8e3d30f8ee2f6..acc3d3ca0b74f4898523e4af0452f65463d8b94f 100644 --- a/tensorflow/docs_src/programmers_guide/saved_model.md +++ b/tensorflow/docs_src/guide/saved_model.md @@ -3,7 +3,7 @@ The @{tf.train.Saver} class provides methods to save and restore models. The @{tf.saved_model.simple_save} function is an easy way to build a @{tf.saved_model$saved model} suitable for serving. -[Estimators](@{$programmers_guide/estimators}) automatically save and restore +[Estimators](@{$guide/estimators}) automatically save and restore variables in the `model_dir`. ## Save and restore variables @@ -299,7 +299,7 @@ following: added attributes with defaults don't cause older model consumers to fail loading models regenerated with newer training binaries. -See [compatibility guidance](https://www.tensorflow.org/programmers_guide/version_compat) +See [compatibility guidance](./version_compat.md) for more information. ### Loading a SavedModel in Python @@ -794,11 +794,12 @@ Here's the syntax: ``` usage: saved_model_cli run [-h] --dir DIR --tag_set TAG_SET --signature_def SIGNATURE_DEF_KEY [--inputs INPUTS] - [--input_exprs INPUT_EXPRS] [--outdir OUTDIR] + [--input_exprs INPUT_EXPRS] + [--input_examples INPUT_EXAMPLES] [--outdir OUTDIR] [--overwrite] [--tf_debug] ``` -The `run` command provides the following two ways to pass inputs to the model: +The `run` command provides the following three ways to pass inputs to the model: * `--inputs` option enables you to pass numpy ndarray in files. * `--input_exprs` option enables you to pass Python expressions. @@ -847,7 +848,7 @@ dictionary is stored in the pickle file and the value corresponding to the *variable_name* will be used. -#### `--inputs_exprs` +#### `--input_exprs` To pass inputs through Python expressions, specify the `--input_exprs` option. This can be useful for when you don't have data @@ -869,7 +870,7 @@ example: (Note that the `numpy` module is already available to you as `np`.) -#### `--inputs_examples` +#### `--input_examples` To pass `tf.train.Example` as inputs, specify the `--input_examples` option. For each input key, it takes a list of dictionary, where each dictionary is an diff --git a/tensorflow/docs_src/programmers_guide/summaries_and_tensorboard.md b/tensorflow/docs_src/guide/summaries_and_tensorboard.md similarity index 100% rename from tensorflow/docs_src/programmers_guide/summaries_and_tensorboard.md rename to tensorflow/docs_src/guide/summaries_and_tensorboard.md diff --git a/tensorflow/docs_src/programmers_guide/tensorboard_histograms.md b/tensorflow/docs_src/guide/tensorboard_histograms.md similarity index 100% rename from tensorflow/docs_src/programmers_guide/tensorboard_histograms.md rename to tensorflow/docs_src/guide/tensorboard_histograms.md diff --git a/tensorflow/docs_src/programmers_guide/tensors.md b/tensorflow/docs_src/guide/tensors.md similarity index 98% rename from tensorflow/docs_src/programmers_guide/tensors.md rename to tensorflow/docs_src/guide/tensors.md index 1248c3cabe23c8d5f200fc1bf46e60851ba532a6..7227260f1a4ee08309f42d21bab8eaa3c77e3297 100644 --- a/tensorflow/docs_src/programmers_guide/tensors.md +++ b/tensorflow/docs_src/guide/tensors.md @@ -26,7 +26,7 @@ some cases it's only possible to find the shape of a tensor at graph execution time. Some types of tensors are special, and these will be covered in other -units of the Programmer's guide. The main ones are: +units of the TensorFlow guide. The main ones are: * `tf.Variable` * `tf.constant` @@ -230,7 +230,7 @@ yet_another = tf.reshape(matrixAlt, [13, 2, -1]) # ERROR! ## Data types In addition to dimensionality, Tensors have a data type. Refer to the -`tf.DataType` page in the programmer's guide for a full list of the data types. +`tf.DType` page for a complete list of the data types. It is not possible to have a `tf.Tensor` with more than one data type. It is possible, however, to serialize arbitrary data structures as `string`s and store diff --git a/tensorflow/docs_src/programmers_guide/using_gpu.md b/tensorflow/docs_src/guide/using_gpu.md similarity index 100% rename from tensorflow/docs_src/programmers_guide/using_gpu.md rename to tensorflow/docs_src/guide/using_gpu.md diff --git a/tensorflow/docs_src/programmers_guide/using_tpu.md b/tensorflow/docs_src/guide/using_tpu.md similarity index 98% rename from tensorflow/docs_src/programmers_guide/using_tpu.md rename to tensorflow/docs_src/guide/using_tpu.md index 44aabf05571bb7f325a5d642f06362e0088607d2..41d80d9d60694c87675f07d8045713d9a117c7f1 100644 --- a/tensorflow/docs_src/programmers_guide/using_tpu.md +++ b/tensorflow/docs_src/guide/using_tpu.md @@ -171,7 +171,7 @@ This section details the changes you must make to the model function During regular usage TensorFlow attempts to determine the shapes of each `tf.Tensor` during graph construction. During execution any unknown shape dimensions are determined dynamically, -see @{$programmers_guide/tensors#shape$Tensor Shapes} for more details. +see @{$guide/tensors#shape$Tensor Shapes} for more details. To run on Cloud TPUs TensorFlow models are compiled using @{$xla$XLA}. XLA uses a similar system for determining shapes at compile time. XLA requires @@ -195,7 +195,7 @@ TPU. Build your evaluation metrics dictionary in a stand-alone `metric_fn`. - + Evaluation metrics are an essential part of training a model. These are fully supported on Cloud TPUs, but with a slightly different syntax. diff --git a/tensorflow/docs_src/programmers_guide/variables.md b/tensorflow/docs_src/guide/variables.md similarity index 100% rename from tensorflow/docs_src/programmers_guide/variables.md rename to tensorflow/docs_src/guide/variables.md diff --git a/tensorflow/docs_src/programmers_guide/version_compat.md b/tensorflow/docs_src/guide/version_compat.md similarity index 100% rename from tensorflow/docs_src/programmers_guide/version_compat.md rename to tensorflow/docs_src/guide/version_compat.md diff --git a/tensorflow/docs_src/install/install_go.md b/tensorflow/docs_src/install/install_go.md index 55bc0f64e799ecb7115cc656c8a08ec1ce2a6108..2c126df5aa6263127fcdd7a9b01efcbaf3c15c46 100644 --- a/tensorflow/docs_src/install/install_go.md +++ b/tensorflow/docs_src/install/install_go.md @@ -6,7 +6,7 @@ a Go application. This guide explains how to install and set up the [TensorFlow Go package](https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go). Warning: The TensorFlow Go API is *not* covered by the TensorFlow -[API stability guarantees](https://www.tensorflow.org/programmers_guide/version_semantics). +[API stability guarantees](../guide/version_semantics.md). ## Supported Platforms diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md index 637231da1252097a9a143edea23f3248c7cf2eb6..692dfc9cefe89b9d39a18a55284f8b521d620100 100644 --- a/tensorflow/docs_src/install/install_java.md +++ b/tensorflow/docs_src/install/install_java.md @@ -7,7 +7,7 @@ Java application. This guide explains how to install and use it in a Java application. Warning: The TensorFlow Java API is *not* covered by the TensorFlow -[API stability guarantees](https://www.tensorflow.org/programmers_guide/version_semantics). +[API stability guarantees](../guide/version_semantics.md). ## Supported Platforms diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md index c8d706cf3c232d1ea91265bd7ad38d5227c440f0..c573acaf458a5c0bb52b7c3b314bd52ae60c4577 100644 --- a/tensorflow/docs_src/install/install_linux.md +++ b/tensorflow/docs_src/install/install_linux.md @@ -489,13 +489,7 @@ TensorFlow programs: If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). -If you are new to machine learning, we recommend the following: - -* [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course) -* @{$get_started/eager} - -If you are experienced with machine learning but new to TensorFlow, see -@{$get_started/eager}. +To learn more, see [Get Started with TensorFlow](https://www.tensorflow.org/get_started). ## TensorFlow GPU support diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md index 9d01271c5a0beebf75be9e32682583ddc4a666b1..584f1e2e35caff32a4f8aea5ab5fe94114470219 100644 --- a/tensorflow/docs_src/install/install_mac.md +++ b/tensorflow/docs_src/install/install_mac.md @@ -403,11 +403,7 @@ writing TensorFlow programs: If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). -If you are new to machine learning, we recommend the -[Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course). - -If you are experienced with machine learning but new to TensorFlow, see -@{$get_started/eager}. +To learn more, see [Get Started with TensorFlow](https://www.tensorflow.org/get_started). ## Common installation problems diff --git a/tensorflow/docs_src/install/install_raspbian.md b/tensorflow/docs_src/install/install_raspbian.md index 2f425162a1c63f084702727b5280ed266196b955..0caab6d335544bfc291894a79f9ed0441eb03561 100644 --- a/tensorflow/docs_src/install/install_raspbian.md +++ b/tensorflow/docs_src/install/install_raspbian.md @@ -230,11 +230,7 @@ problems, despite the log message. If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). -If you are new to machine learning, we recommend the [Machine Learning Crash -Course](https://developers.google.com/machine-learning/crash-course). - -If you are experienced with machine learning but new to TensorFlow, see -@{$get_started/eager}. +To learn more, see [Get Started with TensorFlow](https://www.tensorflow.org/get_started). ## Common installation problems diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md index dc6c1e36fc237c2a160887e6417e7f373008309e..a641dc3a6f5436b3c321a0216fca7ac90d554b63 100644 --- a/tensorflow/docs_src/install/install_sources.md +++ b/tensorflow/docs_src/install/install_sources.md @@ -289,17 +289,27 @@ Note: If you're only interested in building the libraries for the TensorFlow C or Java APIs, see [Build the C or Java libraries](#BuildCorJava), you do not need to build the pip package in that case. -To build a pip package for TensorFlow with CPU-only support, -you would typically invoke the following command: +### CPU-only support + +To build a pip package for TensorFlow with CPU-only support: + +
+$ bazel build --config=opt //tensorflow/tools/pip_package:build_pip_package
+
+ +To build a pip package for TensorFlow with CPU-only support for the IntelĀ® MKL-DNN:
-$ bazel build --config=opt //tensorflow/tools/pip_package:build_pip_package
+$ bazel build --config=mkl --config=opt //tensorflow/tools/pip_package:build_pip_package
 
-To build a pip package for TensorFlow with GPU support, -invoke the following command: +### GPU support + +To build a pip package for TensorFlow with GPU support: -
$ bazel build --config=opt --config=cuda //tensorflow/tools/pip_package:build_pip_package 
+
+$ bazel build --config=opt --config=cuda //tensorflow/tools/pip_package:build_pip_package
+
**NOTE on gcc 5 or later:** the binary pip packages available on the TensorFlow website are built with gcc 4, which uses the older ABI. To @@ -362,7 +372,7 @@ TensorFlow programs:
Hello, TensorFlow!
-If you are new to TensorFlow, see @{$get_started/eager}. +To learn more, see [Get Started with TensorFlow](https://www.tensorflow.org/get_started). If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). diff --git a/tensorflow/docs_src/install/install_windows.md b/tensorflow/docs_src/install/install_windows.md index 6c4f5b85ab2facdb274e9bdd36f6edb9ad79ba4b..7fe94f0bc3850b7210e83f746f8f8fd5b343cbd3 100644 --- a/tensorflow/docs_src/install/install_windows.md +++ b/tensorflow/docs_src/install/install_windows.md @@ -157,12 +157,7 @@ TensorFlow programs: If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). -If you are new to machine learning, we recommend the -[Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course). - -If you are experienced with machine learning but new to TensorFlow, see -@{$get_started/eager}. - +To learn more, see [Get Started with TensorFlow](https://www.tensorflow.org/get_started). ## Common installation problems diff --git a/tensorflow/docs_src/install/leftnav_files b/tensorflow/docs_src/install/leftnav_files index e523e06f67aad508238ee0965f34ebe16c77bf90..ace275c0e82b794708bfc63c0e61d6bb3251a152 100644 --- a/tensorflow/docs_src/install/leftnav_files +++ b/tensorflow/docs_src/install/leftnav_files @@ -4,6 +4,7 @@ index.md install_linux.md: Ubuntu install_mac.md: MacOS install_windows.md: Windows +install_raspbian.md: Raspbian install_sources.md: From source >>> migration.md diff --git a/tensorflow/docs_src/mobile/tflite/demo_android.md b/tensorflow/docs_src/mobile/tflite/demo_android.md index 6f9893f8f18b4d94dee887ce797f4a9440ed1a8a..fdf0bcf3c1135f0e702c7dda4d1d608a26169470 100644 --- a/tensorflow/docs_src/mobile/tflite/demo_android.md +++ b/tensorflow/docs_src/mobile/tflite/demo_android.md @@ -1,7 +1,7 @@ # Android Demo App An example Android application using TensorFLow Lite is available -[on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app). +[on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo). The demo is a sample camera app that classifies images continuously using either a quantized Mobilenet model or a floating point Inception-v3 model. To run the demo, a device running Android 5.0 ( API 21) or higher is required. diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md index f7e116bf0f85fe94b2167eca8b623207432b38e9..4c4f3f39348f59aa018d19d4a7368f09bcef89ed 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/docs_src/performance/xla/operation_semantics.md @@ -1308,12 +1308,10 @@ See also : : : parameters of type T and M of : : : : arbitrary type : | `dimensions` | `int64` array | array of map dimensions | -| `static_operands` | sequence of M `XlaOp`s | M arrays of arbitrary type | Applies a scalar function over the given `operands` arrays, producing an array of the same dimensions where each element is the result of the mapped function -applied to the corresponding elements in the input arrays with `static_operands` -given as additional input to `computation`. +applied to the corresponding elements in the input arrays. The mapped function is an arbitrary computation with the restriction that it has N inputs of scalar type `T` and a single output with type `S`. The output has @@ -2012,13 +2010,35 @@ Slice(b, {2, 1}, {4, 3}) produces: See also [`XlaBuilder::Sort`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). -Sorts the elements in the operand. +There are two versions of the Sort instruction: a single-operand and a +two-operand version. `Sort(operand)` +Arguments | Type | Semantics +--------- | ------- | -------------------- +`operand` | `XlaOp` | The operand to sort. + +Sorts the elements in the operand in ascending order. The operand must be rank-1. +If the operand's elements have floating point type, and the operand contains +NaN elements, the order of elements in the output is implementation-defined. + +`Sort(key, value)` + +Sorts both the key and the value operands. The keys are sorted as in the +single-operand version. The values are sorted according to the order of their +corresponding keys. For example, if the inputs are `keys = [3, 1]` and +`values = [42, 50]`, then the output of the sort is the tuple `{[1, 3], [50, 42]}`. +The sort is not guaranteed to be stable, that is, if the keys array contains +duplicates, the order of their corresponding values may not be preserved. + Arguments | Type | Semantics --------- | ------- | ------------------- -`operand` | `XlaOp` | The operand to sort +`keys` | `XlaOp` | The sort keys. +`values` | `XlaOp` | The values to sort. + +The `keys` and `values` operand must both be rank-1, and must have the same +dimensions, but may have different element types. ## Transpose diff --git a/tensorflow/docs_src/tutorials/deep_cnn.md b/tensorflow/docs_src/tutorials/deep_cnn.md index 6a4c9a9b0727208a158b1b57d13ca70290961ec2..44a32d9d1dcbd7d4be7a2063e9c5ae4affffe487 100644 --- a/tensorflow/docs_src/tutorials/deep_cnn.md +++ b/tensorflow/docs_src/tutorials/deep_cnn.md @@ -268,7 +268,7 @@ in `cifar10_input.py`. `cifar10_train.py` periodically @{tf.train.Saver$saves} all model parameters in -@{$programmers_guide/saved_model$checkpoint files} +@{$guide/saved_model$checkpoint files} but it does *not* evaluate the model. The checkpoint file will be used by `cifar10_eval.py` to measure the predictive performance (see [Evaluating a Model](#evaluating-a-model) below). diff --git a/tensorflow/docs_src/tutorials/index.md b/tensorflow/docs_src/tutorials/index.md index af01d3eaa12157f82c981de005708509f6652cca..6bd3a3a897d9cc11e9172e4ccde6fcad4f075ad1 100644 --- a/tensorflow/docs_src/tutorials/index.md +++ b/tensorflow/docs_src/tutorials/index.md @@ -2,9 +2,8 @@ This section contains tutorials demonstrating how to do specific tasks -in TensorFlow. If you are new to TensorFlow, we recommend reading the -documents in the "@{$get_started$Get Started}" section before reading -these tutorials. +in TensorFlow. If you are new to TensorFlow, we recommend reading +[Get Started with TensorFlow](/get_started/). ## Images diff --git a/tensorflow/docs_src/tutorials/layers.md b/tensorflow/docs_src/tutorials/layers.md index 0f17899dae7ccd8686ac159548dec303401b8ad4..791909f5fd5be2913af1a093d967c9fbb6af89a3 100644 --- a/tensorflow/docs_src/tutorials/layers.md +++ b/tensorflow/docs_src/tutorials/layers.md @@ -470,51 +470,18 @@ as the loss metric. The following code calculates cross entropy when the model runs in either `TRAIN` or `EVAL` mode: ```python -onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=10) -loss = tf.losses.softmax_cross_entropy( - onehot_labels=onehot_labels, logits=logits) +loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) ``` Let's take a closer look at what's happening above. -Our `labels` tensor contains a list of predictions for our examples, e.g. `[1, -9, ...]`. In order to calculate cross-entropy, first we need to convert `labels` -to the corresponding -[one-hot encoding](https://www.quora.com/What-is-one-hot-encoding-and-when-is-it-used-in-data-science): +Our `labels` tensor contains a list of prediction indices for our examples, e.g. `[1, +9, ...]`. `logits` contains the linear outputs of our last layer. -```none -[[0, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], - ...] -``` - -We use the @{tf.one_hot} function -to perform this conversion. `tf.one_hot()` has two required arguments: - -* `indices`. The locations in the one-hot tensor that will have "on - values"—i.e., the locations of `1` values in the tensor shown above. -* `depth`. The depth of the one-hot tensor—i.e., the number of target classes. - Here, the depth is `10`. +`tf.losses.sparse_softmax_cross_entropy`, calculates the softmax crossentropy +(aka: categorical crossentropy, negative log-likelihood) from these two inputs +in an efficient, numerically stable way. -The following code creates the one-hot tensor for our labels, `onehot_labels`: - -```python -onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=10) -``` - -Because `labels` contains a series of values from 0–9, `indices` is just our -`labels` tensor, with values cast to integers. The `depth` is `10` because we -have 10 possible target classes, one for each digit. - -Next, we compute cross-entropy of `onehot_labels` and the softmax of the -predictions from our logits layer. `tf.losses.softmax_cross_entropy()` takes -`onehot_labels` and `logits` as arguments, performs softmax activation on -`logits`, calculates cross-entropy, and returns our `loss` as a scalar `Tensor`: - -```python -loss = tf.losses.softmax_cross_entropy( - onehot_labels=onehot_labels, logits=logits) -``` ### Configure the Training Op @@ -627,7 +594,7 @@ operation earlier when we generated the probabilities in `cnn_model_fn`. > argument, TensorFlow will assign a default name. A couple easy ways to > discover the names applied to operations are to visualize your graph on > @{$graph_viz$TensorBoard}) or to enable the -> @{$programmers_guide/debugger$TensorFlow Debugger (tfdbg)}. +> @{$guide/debugger$TensorFlow Debugger (tfdbg)}. Next, we create the `LoggingTensorHook`, passing `tensors_to_log` to the `tensors` argument. We set `every_n_iter=50`, which specifies that probabilities diff --git a/tensorflow/docs_src/tutorials/leftnav_files b/tensorflow/docs_src/tutorials/leftnav_files index 888052428f951fa1a7cbd9c6d35497a056387097..eadd410d0812cfecbcb7cb01550e2f7e7f9da0db 100644 --- a/tensorflow/docs_src/tutorials/leftnav_files +++ b/tensorflow/docs_src/tutorials/leftnav_files @@ -3,10 +3,11 @@ index.md ### Images layers.md: MNIST image_recognition.md: Image Recognition -image_retraining.md: Image Retraining +/hub/tutorials/image_retraining.md: Image Retraining deep_cnn.md ### Sequences +/hub/tutorials/text_classification_with_tf_hub: Text Classification recurrent.md seq2seq.md: Neural Machine Translation recurrent_quickdraw.md: Drawing Classification diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD index 07f096418f53219c9ec7000a4560d78a3ff609e1..f327b645f58f35cedd27baa8ab521e334c8e7b15 100644 --- a/tensorflow/examples/android/BUILD +++ b/tensorflow/examples/android/BUILD @@ -1,6 +1,8 @@ # Description: # TensorFlow camera demo app for Android. +load("@build_bazel_rules_android//android:rules.bzl", "android_binary") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py index 307eede5c03780e9244b035f020fc7846290d4d9..740224744860fdd76bea9c4531242a4976b20784 100644 --- a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py +++ b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py @@ -17,7 +17,7 @@ This version is like fully_connected_feed.py but uses data converted to a TFRecords file containing tf.train.Example protocol buffers. See: -https://www.tensorflow.org/programmers_guide/reading_data#reading_from_files +https://www.tensorflow.org/guide/reading_data#reading_from_files for context. YOU MUST run convert_to_records before running this (but you only need to diff --git a/tensorflow/go/attrs_test.go b/tensorflow/go/attrs_test.go index 35b0cb352e7a5c1ca2e465720cd4dc125f166675..ea8af221aeef3bf1d2edeab4372ae00f0cc7e92d 100644 --- a/tensorflow/go/attrs_test.go +++ b/tensorflow/go/attrs_test.go @@ -28,7 +28,7 @@ func TestOperationAttrs(t *testing.T) { i := 0 makeConst := func(v interface{}) Output { op, err := Const(g, fmt.Sprintf("const/%d/%+v", i, v), v) - i += 1 + i++ if err != nil { t.Fatal(err) } @@ -71,6 +71,7 @@ func TestOperationAttrs(t *testing.T) { "boundaries": []float32(nil), }, }, + /* TODO(ashankar): debug this issue and add it back later. { Name: "list(type),list(shape)", Type: "InfeedEnqueueTuple", @@ -111,6 +112,7 @@ func TestOperationAttrs(t *testing.T) { "device_ordinal": int64(0), }, }, + */ { Name: "list(int),int", Type: "StringToHashBucketStrong", diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index a5224fbda04fdfb2460fce96efffb6eab4f08551..7f1f0970a6fd697419b4158f3a6517bca5bbe10e 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -2990,34 +2990,55 @@ func Split(scope *Scope, axis tf.Output, value tf.Output, num_split int64) (outp return output } -// Creates a sequence of numbers. +// Concatenates tensors along one dimension. // -// This operation creates a sequence of numbers that begins at `start` and -// extends by increments of `delta` up to but not including `limit`. +// Arguments: +// concat_dim: 0-D. The dimension along which to concatenate. Must be in the +// range [0, rank(values)). +// values: The `N` Tensors to concatenate. Their ranks and types must match, +// and their sizes must match in all dimensions except `concat_dim`. // -// For example: +// Returns A `Tensor` with the concatenation of values stacked along the +// `concat_dim` dimension. This tensor's shape matches that of `values` except +// in `concat_dim` where it has the sum of the sizes. +func Concat(scope *Scope, concat_dim tf.Output, values []tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Concat", + Input: []tf.Input{ + concat_dim, tf.OutputList(values), + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Converts a flat index or array of flat indices into a tuple of // -// ``` -// # 'start' is 3 -// # 'limit' is 18 -// # 'delta' is 3 -// tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15] -// ``` +// coordinate arrays. +// +// @compatibility(numpy) +// Equivalent to np.unravel_index +// @end_compatibility // // Arguments: -// start: 0-D (scalar). First entry in the sequence. -// limit: 0-D (scalar). Upper limit of sequence, exclusive. -// delta: 0-D (scalar). Optional. Default is 1. Number that increments `start`. +// indices: An 0-D or 1-D `int` Tensor whose elements are indices into the +// flattened version of an array of dimensions dims. +// dims: An 1-D `int` Tensor. The shape of the array to use for unraveling +// indices. // -// Returns 1-D. -func Range(scope *Scope, start tf.Output, limit tf.Output, delta tf.Output) (output tf.Output) { +// Returns An 2-D (or 1-D if indices is 0-D) tensor where each row has the +// same shape as the indices array. +func UnravelIndex(scope *Scope, indices tf.Output, dims tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Range", + Type: "UnravelIndex", Input: []tf.Input{ - start, limit, delta, + indices, dims, }, } op := scope.AddOperation(opspec) @@ -3923,24 +3944,6 @@ func AddV2(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// Returns x + y element-wise. -// -// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Add(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Add", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // NthElementAttr is an optional argument to NthElement. type NthElementAttr func(optionalAttr) @@ -4684,6 +4687,24 @@ func MatrixInverse(scope *Scope, input tf.Output, optional ...MatrixInverseAttr) return op.Output(0) } +// Returns x + y element-wise. +// +// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Add(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Add", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Computes square of x element-wise. // // I.e., \\(y = x * x = x^2\\). @@ -7789,121 +7810,6 @@ func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (o return op.Output(0) } -// LRNGradAttr is an optional argument to LRNGrad. -type LRNGradAttr func(optionalAttr) - -// LRNGradDepthRadius sets the optional depth_radius attribute to value. -// -// value: A depth radius. -// If not specified, defaults to 5 -func LRNGradDepthRadius(value int64) LRNGradAttr { - return func(m optionalAttr) { - m["depth_radius"] = value - } -} - -// LRNGradBias sets the optional bias attribute to value. -// -// value: An offset (usually > 0 to avoid dividing by 0). -// If not specified, defaults to 1 -func LRNGradBias(value float32) LRNGradAttr { - return func(m optionalAttr) { - m["bias"] = value - } -} - -// LRNGradAlpha sets the optional alpha attribute to value. -// -// value: A scale factor, usually positive. -// If not specified, defaults to 1 -func LRNGradAlpha(value float32) LRNGradAttr { - return func(m optionalAttr) { - m["alpha"] = value - } -} - -// LRNGradBeta sets the optional beta attribute to value. -// -// value: An exponent. -// If not specified, defaults to 0.5 -func LRNGradBeta(value float32) LRNGradAttr { - return func(m optionalAttr) { - m["beta"] = value - } -} - -// Gradients for Local Response Normalization. -// -// Arguments: -// input_grads: 4-D with shape `[batch, height, width, channels]`. -// input_image: 4-D with shape `[batch, height, width, channels]`. -// output_image: 4-D with shape `[batch, height, width, channels]`. -// -// Returns The gradients for LRN. -func LRNGrad(scope *Scope, input_grads tf.Output, input_image tf.Output, output_image tf.Output, optional ...LRNGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LRNGrad", - Input: []tf.Input{ - input_grads, input_image, output_image, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// AnyAttr is an optional argument to Any. -type AnyAttr func(optionalAttr) - -// AnyKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func AnyKeepDims(value bool) AnyAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the "logical or" of elements across dimensions of a tensor. -// -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. -// -// Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. -// -// Returns The reduced tensor. -func Any(scope *Scope, input tf.Output, axis tf.Output, optional ...AnyAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Any", - Input: []tf.Input{ - input, axis, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl. type ResourceApplyFtrlAttr func(optionalAttr) @@ -8284,27 +8190,29 @@ func DataFormatVecPermute(scope *Scope, x tf.Output, optional ...DataFormatVecPe return op.Output(0) } -// Reads the value of a variable. +// Converts each string in the input Tensor to its hash mod by a number of buckets. // -// The tensor returned by this operation is immutable. +// The hash function is deterministic on the content of the string within the +// process. // -// The value returned by this operation is guaranteed to be influenced by all the -// writes on which this operation depends directly or indirectly, and to not be -// influenced by any of the writes which depend directly or indirectly on this -// operation. +// Note that the hash function may change from time to time. +// This functionality will be deprecated and it's recommended to use +// `tf.string_to_hash_bucket_fast()` or `tf.string_to_hash_bucket_strong()`. // // Arguments: -// resource: handle to the resource in which to store the variable. -// dtype: the dtype of the value. -func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value tf.Output) { +// +// num_buckets: The number of buckets. +// +// Returns A Tensor of the same shape as the input `string_tensor`. +func StringToHashBucket(scope *Scope, string_tensor tf.Output, num_buckets int64) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{"num_buckets": num_buckets} opspec := tf.OpSpec{ - Type: "ReadVariableOp", + Type: "StringToHashBucket", Input: []tf.Input{ - resource, + string_tensor, }, Attrs: attrs, } @@ -8312,126 +8220,88 @@ func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value return op.Output(0) } -// Computes tan of x element-wise. -func Tan(scope *Scope, x tf.Output) (y tf.Output) { +// Computes gradients for the exponential linear (Elu) operation. +// +// Arguments: +// gradients: The backpropagated gradients to the corresponding Elu operation. +// outputs: The outputs of the corresponding Elu operation. +// +// Returns The gradients: `gradients * (outputs + 1)` if outputs < 0, +// `gradients` otherwise. +func EluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Tan", + Type: "EluGrad", Input: []tf.Input{ - x, + gradients, outputs, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Updates the tree ensemble by either adding a layer to the last tree being grown -// -// or by starting a new tree. +// Creates a dataset that contains `count` elements from the `input_dataset`. // // Arguments: -// tree_ensemble_handle: Handle to the ensemble variable. -// feature_ids: Rank 1 tensor with ids for each feature. This is the real id of -// the feature that will be used in the split. -// node_ids: List of rank 1 tensors representing the nodes for which this feature -// has a split. -// gains: List of rank 1 tensors representing the gains for each of the feature's -// split. -// thresholds: List of rank 1 tensors representing the thesholds for each of the -// feature's split. -// left_node_contribs: List of rank 2 tensors with left leaf contribs for each of -// the feature's splits. Will be added to the previous node values to constitute -// the values of the left nodes. -// right_node_contribs: List of rank 2 tensors with right leaf contribs for each -// of the feature's splits. Will be added to the previous node values to constitute -// the values of the right nodes. -// max_depth: Max depth of the tree to build. -// learning_rate: shrinkage const for each new tree. -// pruning_mode: 0-No pruning, 1-Pre-pruning, 2-Post-pruning. // -// Returns the created operation. -func BoostedTreesUpdateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, feature_ids tf.Output, node_ids []tf.Output, gains []tf.Output, thresholds []tf.Output, left_node_contribs []tf.Output, right_node_contribs []tf.Output, max_depth tf.Output, learning_rate tf.Output, pruning_mode int64) (o *tf.Operation) { +// count: A scalar representing the number of elements from the `input_dataset` +// that should be taken. A value of `-1` indicates that all of `input_dataset` +// is taken. +// +// +func TakeDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"pruning_mode": pruning_mode} + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "BoostedTreesUpdateEnsemble", + Type: "TakeDataset", Input: []tf.Input{ - tree_ensemble_handle, feature_ids, tf.OutputList(node_ids), tf.OutputList(gains), tf.OutputList(thresholds), tf.OutputList(left_node_contribs), tf.OutputList(right_node_contribs), max_depth, learning_rate, + input_dataset, count, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// ResourceSparseApplyFtrlAttr is an optional argument to ResourceSparseApplyFtrl. -type ResourceSparseApplyFtrlAttr func(optionalAttr) - -// ResourceSparseApplyFtrlUseLocking sets the optional use_locking attribute to value. +// Reads the value of a variable. // -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyFtrlUseLocking(value bool) ResourceSparseApplyFtrlAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update relevant entries in '*var' according to the Ftrl-proximal scheme. +// The tensor returned by this operation is immutable. // -// That is for rows we have grad for, we update var, accum and linear as follows: -// accum_new = accum + grad * grad -// linear += grad + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var -// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 -// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 -// accum = accum_new +// The value returned by this operation is guaranteed to be influenced by all the +// writes on which this operation depends directly or indirectly, and to not be +// influenced by any of the writes which depend directly or indirectly on this +// operation. // // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// linear: Should be from a Variable(). -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// lr: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// lr_power: Scaling factor. Must be a scalar. -// -// Returns the created operation. -func ResourceSparseApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlAttr) (o *tf.Operation) { +// resource: handle to the resource in which to store the variable. +// dtype: the dtype of the value. +func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"dtype": dtype} opspec := tf.OpSpec{ - Type: "ResourceSparseApplyFtrl", + Type: "ReadVariableOp", Input: []tf.Input{ - var_, accum, linear, grad, indices, lr, l1, l2, lr_power, + resource, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Returns which elements of x are Inf. -// -// @compatibility(numpy) -// Equivalent to np.isinf -// @end_compatibility -func IsInf(scope *Scope, x tf.Output) (y tf.Output) { +// Computes tan of x element-wise. +func Tan(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "IsInf", + Type: "Tan", Input: []tf.Input{ x, }, @@ -8440,147 +8310,175 @@ func IsInf(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// Computes the sum along sparse segments of a tensor divided by the sqrt of N. -// -// N is the size of the segment being reduced. +// Updates the tree ensemble by either adding a layer to the last tree being grown // -// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of -// segments. +// or by starting a new tree. // // Arguments: +// tree_ensemble_handle: Handle to the ensemble variable. +// feature_ids: Rank 1 tensor with ids for each feature. This is the real id of +// the feature that will be used in the split. +// node_ids: List of rank 1 tensors representing the nodes for which this feature +// has a split. +// gains: List of rank 1 tensors representing the gains for each of the feature's +// split. +// thresholds: List of rank 1 tensors representing the thesholds for each of the +// feature's split. +// left_node_contribs: List of rank 2 tensors with left leaf contribs for each of +// the feature's splits. Will be added to the previous node values to constitute +// the values of the left nodes. +// right_node_contribs: List of rank 2 tensors with right leaf contribs for each +// of the feature's splits. Will be added to the previous node values to constitute +// the values of the right nodes. +// max_depth: Max depth of the tree to build. +// learning_rate: shrinkage const for each new tree. +// pruning_mode: 0-No pruning, 1-Pre-pruning, 2-Post-pruning. // -// indices: A 1-D tensor. Has same rank as `segment_ids`. -// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SparseSegmentSqrtN(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { +// Returns the created operation. +func BoostedTreesUpdateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, feature_ids tf.Output, node_ids []tf.Output, gains []tf.Output, thresholds []tf.Output, left_node_contribs []tf.Output, right_node_contribs []tf.Output, max_depth tf.Output, learning_rate tf.Output, pruning_mode int64) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"pruning_mode": pruning_mode} opspec := tf.OpSpec{ - Type: "SparseSegmentSqrtN", + Type: "BoostedTreesUpdateEnsemble", Input: []tf.Input{ - data, indices, segment_ids, + tree_ensemble_handle, feature_ids, tf.OutputList(node_ids), tf.OutputList(gains), tf.OutputList(thresholds), tf.OutputList(left_node_contribs), tf.OutputList(right_node_contribs), max_depth, learning_rate, }, + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Adds up a `SparseTensor` and a dense `Tensor`, producing a dense `Tensor`. -// -// This Op does not require `a_indices` be sorted in standard lexicographic order. +// EncodeJpegAttr is an optional argument to EncodeJpeg. +type EncodeJpegAttr func(optionalAttr) + +// EncodeJpegFormat sets the optional format attribute to value. // -// Arguments: -// a_indices: 2-D. The `indices` of the `SparseTensor`, with shape `[nnz, ndims]`. -// a_values: 1-D. The `values` of the `SparseTensor`, with shape `[nnz]`. -// a_shape: 1-D. The `shape` of the `SparseTensor`, with shape `[ndims]`. -// b: `ndims`-D Tensor. With shape `a_shape`. -func SparseTensorDenseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseTensorDenseAdd", - Input: []tf.Input{ - a_indices, a_values, a_shape, b, - }, +// value: Per pixel image format. +// If not specified, defaults to "" +func EncodeJpegFormat(value string) EncodeJpegAttr { + return func(m optionalAttr) { + m["format"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// StatelessTruncatedNormalAttr is an optional argument to StatelessTruncatedNormal. -type StatelessTruncatedNormalAttr func(optionalAttr) - -// StatelessTruncatedNormalDtype sets the optional dtype attribute to value. +// EncodeJpegQuality sets the optional quality attribute to value. // -// value: The type of the output. -// If not specified, defaults to DT_FLOAT -func StatelessTruncatedNormalDtype(value tf.DataType) StatelessTruncatedNormalAttr { +// value: Quality of the compression from 0 to 100 (higher is better and slower). +// If not specified, defaults to 95 +func EncodeJpegQuality(value int64) EncodeJpegAttr { return func(m optionalAttr) { - m["dtype"] = value + m["quality"] = value } } -// Outputs deterministic pseudorandom values from a truncated normal distribution. -// -// The generated values follow a normal distribution with mean 0 and standard -// deviation 1, except that values whose magnitude is more than 2 standard -// deviations from the mean are dropped and re-picked. -// -// The outputs are a deterministic function of `shape` and `seed`. -// -// Arguments: -// shape: The shape of the output tensor. -// seed: 2 seeds (shape [2]). +// EncodeJpegProgressive sets the optional progressive attribute to value. // -// Returns Random values with specified shape. -func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessTruncatedNormalAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) +// value: If True, create a JPEG that loads progressively (coarse to fine). +// If not specified, defaults to false +func EncodeJpegProgressive(value bool) EncodeJpegAttr { + return func(m optionalAttr) { + m["progressive"] = value } - opspec := tf.OpSpec{ - Type: "StatelessTruncatedNormal", - Input: []tf.Input{ - shape, seed, - }, - Attrs: attrs, +} + +// EncodeJpegOptimizeSize sets the optional optimize_size attribute to value. +// +// value: If True, spend CPU/RAM to reduce size with no quality change. +// If not specified, defaults to false +func EncodeJpegOptimizeSize(value bool) EncodeJpegAttr { + return func(m optionalAttr) { + m["optimize_size"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// RestoreSliceAttr is an optional argument to RestoreSlice. -type RestoreSliceAttr func(optionalAttr) +// EncodeJpegChromaDownsampling sets the optional chroma_downsampling attribute to value. +// +// value: See http://en.wikipedia.org/wiki/Chroma_subsampling. +// If not specified, defaults to true +func EncodeJpegChromaDownsampling(value bool) EncodeJpegAttr { + return func(m optionalAttr) { + m["chroma_downsampling"] = value + } +} -// RestoreSlicePreferredShard sets the optional preferred_shard attribute to value. +// EncodeJpegDensityUnit sets the optional density_unit attribute to value. // -// value: Index of file to open first if multiple files match -// `file_pattern`. See the documentation for `Restore`. -// If not specified, defaults to -1 -func RestoreSlicePreferredShard(value int64) RestoreSliceAttr { +// value: Unit used to specify `x_density` and `y_density`: +// pixels per inch (`'in'`) or centimeter (`'cm'`). +// If not specified, defaults to "in" +func EncodeJpegDensityUnit(value string) EncodeJpegAttr { return func(m optionalAttr) { - m["preferred_shard"] = value + m["density_unit"] = value } } -// Restores a tensor from checkpoint files. +// EncodeJpegXDensity sets the optional x_density attribute to value. // -// This is like `Restore` except that restored tensor can be listed as filling -// only a slice of a larger tensor. `shape_and_slice` specifies the shape of the -// larger tensor and the slice that the restored tensor covers. +// value: Horizontal pixels per density unit. +// If not specified, defaults to 300 +func EncodeJpegXDensity(value int64) EncodeJpegAttr { + return func(m optionalAttr) { + m["x_density"] = value + } +} + +// EncodeJpegYDensity sets the optional y_density attribute to value. // -// The `shape_and_slice` input has the same format as the -// elements of the `shapes_and_slices` input of the `SaveSlices` op. +// value: Vertical pixels per density unit. +// If not specified, defaults to 300 +func EncodeJpegYDensity(value int64) EncodeJpegAttr { + return func(m optionalAttr) { + m["y_density"] = value + } +} + +// EncodeJpegXmpMetadata sets the optional xmp_metadata attribute to value. +// +// value: If not empty, embed this XMP metadata in the image header. +// If not specified, defaults to "" +func EncodeJpegXmpMetadata(value string) EncodeJpegAttr { + return func(m optionalAttr) { + m["xmp_metadata"] = value + } +} + +// JPEG-encode an image. +// +// `image` is a 3-D uint8 Tensor of shape `[height, width, channels]`. +// +// The attr `format` can be used to override the color format of the encoded +// output. Values can be: +// +// * `''`: Use a default format based on the number of channels in the image. +// * `grayscale`: Output a grayscale JPEG image. The `channels` dimension +// of `image` must be 1. +// * `rgb`: Output an RGB JPEG image. The `channels` dimension +// of `image` must be 3. +// +// If `format` is not specified or is the empty string, a default format is picked +// in function of the number of channels in `image`: +// +// * 1: Output a grayscale image. +// * 3: Output an RGB image. // // Arguments: -// file_pattern: Must have a single element. The pattern of the files from -// which we read the tensor. -// tensor_name: Must have a single element. The name of the tensor to be -// restored. -// shape_and_slice: Scalar. The shapes and slice specifications to use when -// restoring a tensors. -// dt: The type of the tensor to be restored. +// image: 3-D with shape `[height, width, channels]`. // -// Returns The restored tensor. -func RestoreSlice(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, shape_and_slice tf.Output, dt tf.DataType, optional ...RestoreSliceAttr) (tensor tf.Output) { +// Returns 0-D. JPEG-encoded image. +func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (contents tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dt": dt} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "RestoreSlice", + Type: "EncodeJpeg", Input: []tf.Input{ - file_pattern, tensor_name, shape_and_slice, + image, }, Attrs: attrs, } @@ -8588,57 +8486,59 @@ func RestoreSlice(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, s return op.Output(0) } -// Divides sparse updates into the variable referenced by `resource`. -// -// This operation computes -// -// # Scalar indices -// ref[indices, ...] /= updates[...] -// -// # Vector indices (for each i) -// ref[indices[i], ...] /= updates[i, ...] -// -// # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...] -// -// Duplicate entries are handled correctly: if multiple `indices` reference -// the same location, their contributions multiply. -// -// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. -// -//
-// -//
+// MultinomialAttr is an optional argument to Multinomial. +type MultinomialAttr func(optionalAttr) + +// MultinomialSeed sets the optional seed attribute to value. // -// Arguments: -// resource: Should be from a `Variable` node. -// indices: A tensor of indices into the first dimension of `ref`. -// updates: A tensor of updated values to add to `ref`. +// value: If either seed or seed2 is set to be non-zero, the internal random number +// generator is seeded by the given seed. Otherwise, a random seed is used. +// If not specified, defaults to 0 +func MultinomialSeed(value int64) MultinomialAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// MultinomialSeed2 sets the optional seed2 attribute to value. // -// Returns the created operation. -func ResourceScatterDiv(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func MultinomialSeed2(value int64) MultinomialAttr { + return func(m optionalAttr) { + m["seed2"] = value } - opspec := tf.OpSpec{ - Type: "ResourceScatterDiv", - Input: []tf.Input{ - resource, indices, updates, - }, +} + +// MultinomialOutputDtype sets the optional output_dtype attribute to value. +// If not specified, defaults to DT_INT64 +func MultinomialOutputDtype(value tf.DataType) MultinomialAttr { + return func(m optionalAttr) { + m["output_dtype"] = value } - return scope.AddOperation(opspec) } -// Mutually reduces multiple tensors of identical type and shape. -func CollectiveReduce(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, merge_op string, final_op string, subdiv_offsets []int64) (data tf.Output) { +// Draws samples from a multinomial distribution. +// +// Arguments: +// logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]` +// represents the unnormalized log probabilities for all classes. +// num_samples: 0-D. Number of independent samples to draw for each row slice. +// +// Returns 2-D Tensor with shape `[batch_size, num_samples]`. Each slice `[i, :]` +// contains the drawn class labels with range `[0, num_classes)`. +func Multinomial(scope *Scope, logits tf.Output, num_samples tf.Output, optional ...MultinomialAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "merge_op": merge_op, "final_op": final_op, "subdiv_offsets": subdiv_offsets} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "CollectiveReduce", + Type: "Multinomial", Input: []tf.Input{ - input, + logits, num_samples, }, Attrs: attrs, } @@ -8646,31 +8546,35 @@ func CollectiveReduce(scope *Scope, input tf.Output, group_size int64, group_key return op.Output(0) } -// StatelessRandomNormalAttr is an optional argument to StatelessRandomNormal. -type StatelessRandomNormalAttr func(optionalAttr) +// ResourceSparseApplyAdagradDAAttr is an optional argument to ResourceSparseApplyAdagradDA. +type ResourceSparseApplyAdagradDAAttr func(optionalAttr) -// StatelessRandomNormalDtype sets the optional dtype attribute to value. +// ResourceSparseApplyAdagradDAUseLocking sets the optional use_locking attribute to value. // -// value: The type of the output. -// If not specified, defaults to DT_FLOAT -func StatelessRandomNormalDtype(value tf.DataType) StatelessRandomNormalAttr { +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceSparseApplyAdagradDAUseLocking(value bool) ResourceSparseApplyAdagradDAAttr { return func(m optionalAttr) { - m["dtype"] = value + m["use_locking"] = value } } -// Outputs deterministic pseudorandom values from a normal distribution. -// -// The generated values will have mean 0 and standard deviation 1. -// -// The outputs are a deterministic function of `shape` and `seed`. +// Update entries in '*var' and '*accum' according to the proximal adagrad scheme. // // Arguments: -// shape: The shape of the output tensor. -// seed: 2 seeds (shape [2]). +// var_: Should be from a Variable(). +// gradient_accumulator: Should be from a Variable(). +// gradient_squared_accumulator: Should be from a Variable(). +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// lr: Learning rate. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// global_step: Training step number. Must be a scalar. // -// Returns Random values with specified shape. -func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomNormalAttr) (output tf.Output) { +// Returns the created operation. +func ResourceSparseApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator tf.Output, gradient_squared_accumulator tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, global_step tf.Output, optional ...ResourceSparseApplyAdagradDAAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -8679,186 +8583,166 @@ func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, option a(attrs) } opspec := tf.OpSpec{ - Type: "StatelessRandomNormal", + Type: "ResourceSparseApplyAdagradDA", Input: []tf.Input{ - shape, seed, + var_, gradient_accumulator, gradient_squared_accumulator, grad, indices, lr, l1, l2, global_step, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Reduces sparse updates into the variable referenced by `resource` using the `min` operation. -// -// This operation computes -// -// # Scalar indices -// ref[indices, ...] = min(ref[indices, ...], updates[...]) -// -// # Vector indices (for each i) -// ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...]) -// -// # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) -// -// Duplicate entries are handled correctly: if multiple `indices` reference -// the same location, their contributions are combined. +// ResourceSparseApplyFtrlAttr is an optional argument to ResourceSparseApplyFtrl. +type ResourceSparseApplyFtrlAttr func(optionalAttr) + +// ResourceSparseApplyFtrlUseLocking sets the optional use_locking attribute to value. // -// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyFtrlUseLocking(value bool) ResourceSparseApplyFtrlAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update relevant entries in '*var' according to the Ftrl-proximal scheme. // -//
-// -//
+// That is for rows we have grad for, we update var, accum and linear as follows: +// accum_new = accum + grad * grad +// linear += grad + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 +// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 +// accum = accum_new // // Arguments: -// resource: Should be from a `Variable` node. -// indices: A tensor of indices into the first dimension of `ref`. -// updates: A tensor of updated values to add to `ref`. +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// linear: Should be from a Variable(). +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// lr: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// lr_power: Scaling factor. Must be a scalar. // // Returns the created operation. -func ResourceScatterMin(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { +func ResourceSparseApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlAttr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ResourceScatterMin", + Type: "ResourceSparseApplyFtrl", Input: []tf.Input{ - resource, indices, updates, + var_, accum, linear, grad, indices, lr, l1, l2, lr_power, }, + Attrs: attrs, } return scope.AddOperation(opspec) } -// Reshapes a quantized tensor as per the Reshape op. -// -// ``` -// -// Arguments: -// -// shape: Defines the shape of the output tensor. -// input_min: The minimum value of the input. -// input_max: The maximum value of the input. +// Returns which elements of x are Inf. // -// Returns This value is copied from input_min.This value is copied from input_max. -func QuantizedReshape(scope *Scope, tensor tf.Output, shape tf.Output, input_min tf.Output, input_max tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) { +// @compatibility(numpy) +// Equivalent to np.isinf +// @end_compatibility +func IsInf(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "QuantizedReshape", + Type: "IsInf", Input: []tf.Input{ - tensor, shape, input_min, input_max, + x, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Returns the truth value of (x != y) element-wise. +// Computes the sum along sparse segments of a tensor divided by the sqrt of N. // -// *NOTE*: `NotEqual` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func NotEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// N is the size of the segment being reduced. +// +// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of +// segments. +// +// Arguments: +// +// indices: A 1-D tensor. Has same rank as `segment_ids`. +// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SparseSegmentSqrtN(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "NotEqual", + Type: "SparseSegmentSqrtN", Input: []tf.Input{ - x, y, + data, indices, segment_ids, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Inverse 3D real-valued fast Fourier transform. -// -// Computes the inverse 3-dimensional discrete Fourier transform of a real-valued -// signal over the inner-most 3 dimensions of `input`. -// -// The inner-most 3 dimensions of `input` are assumed to be the result of `RFFT3D`: -// The inner-most dimension contains the `fft_length / 2 + 1` unique components of -// the DFT of a real-valued signal. If `fft_length` is not provided, it is computed -// from the size of the inner-most 3 dimensions of `input`. If the FFT length used -// to compute `input` is odd, it should be provided since it cannot be inferred -// properly. +// Adds up a `SparseTensor` and a dense `Tensor`, producing a dense `Tensor`. // -// Along each axis `IRFFT3D` is computed on, if `fft_length` (or -// `fft_length / 2 + 1` for the inner-most dimension) is smaller than the -// corresponding dimension of `input`, the dimension is cropped. If it is larger, -// the dimension is padded with zeros. +// This Op does not require `a_indices` be sorted in standard lexicographic order. // // Arguments: -// input: A complex64 tensor. -// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. -// -// Returns A float32 tensor of the same rank as `input`. The inner-most 3 -// dimensions of `input` are replaced with the `fft_length` samples of their -// inverse 3D real Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.irfftn with 3 dimensions. -// @end_compatibility -func IRFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { +// a_indices: 2-D. The `indices` of the `SparseTensor`, with shape `[nnz, ndims]`. +// a_values: 1-D. The `values` of the `SparseTensor`, with shape `[nnz]`. +// a_shape: 1-D. The `shape` of the `SparseTensor`, with shape `[ndims]`. +// b: `ndims`-D Tensor. With shape `a_shape`. +func SparseTensorDenseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "IRFFT3D", + Type: "SparseTensorDenseAdd", Input: []tf.Input{ - input, fft_length, + a_indices, a_values, a_shape, b, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// StringSplitAttr is an optional argument to StringSplit. -type StringSplitAttr func(optionalAttr) +// StatelessTruncatedNormalAttr is an optional argument to StatelessTruncatedNormal. +type StatelessTruncatedNormalAttr func(optionalAttr) -// StringSplitSkipEmpty sets the optional skip_empty attribute to value. +// StatelessTruncatedNormalDtype sets the optional dtype attribute to value. // -// value: A `bool`. If `True`, skip the empty strings from the result. -// If not specified, defaults to true -func StringSplitSkipEmpty(value bool) StringSplitAttr { +// value: The type of the output. +// If not specified, defaults to DT_FLOAT +func StatelessTruncatedNormalDtype(value tf.DataType) StatelessTruncatedNormalAttr { return func(m optionalAttr) { - m["skip_empty"] = value + m["dtype"] = value } } -// Split elements of `input` based on `delimiter` into a `SparseTensor`. -// -// Let N be the size of source (typically N will be the batch size). Split each -// element of `input` based on `delimiter` and return a `SparseTensor` -// containing the splitted tokens. Empty tokens are ignored. -// -// `delimiter` can be empty, or a string of split characters. If `delimiter` is an -// empty string, each element of `input` is split into individual single-byte -// character strings, including splitting of UTF-8 multibyte sequences. Otherwise -// every character of `delimiter` is a potential split point. +// Outputs deterministic pseudorandom values from a truncated normal distribution. // -// For example: -// N = 2, input[0] is 'hello world' and input[1] is 'a b c', then the output -// will be +// The generated values follow a normal distribution with mean 0 and standard +// deviation 1, except that values whose magnitude is more than 2 standard +// deviations from the mean are dropped and re-picked. // -// indices = [0, 0; -// 0, 1; -// 1, 0; -// 1, 1; -// 1, 2] -// shape = [2, 3] -// values = ['hello', 'world', 'a', 'b', 'c'] +// The outputs are a deterministic function of `shape` and `seed`. // // Arguments: -// input: 1-D. Strings to split. -// delimiter: 0-D. Delimiter characters (bytes), or empty string. +// shape: The shape of the output tensor. +// seed: 2 seeds (shape [2]). // -// Returns A dense matrix of int64 representing the indices of the sparse tensor.A vector of strings corresponding to the splited values.a length-2 vector of int64 representing the shape of the sparse -// tensor, where the first value is N and the second value is the maximum number -// of tokens in a single input entry. -func StringSplit(scope *Scope, input tf.Output, delimiter tf.Output, optional ...StringSplitAttr) (indices tf.Output, values tf.Output, shape tf.Output) { +// Returns Random values with specified shape. +func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessTruncatedNormalAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -8867,134 +8751,151 @@ func StringSplit(scope *Scope, input tf.Output, delimiter tf.Output, optional .. a(attrs) } opspec := tf.OpSpec{ - Type: "StringSplit", + Type: "StatelessTruncatedNormal", Input: []tf.Input{ - input, delimiter, + shape, seed, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// ResourceSparseApplyMomentumAttr is an optional argument to ResourceSparseApplyMomentum. -type ResourceSparseApplyMomentumAttr func(optionalAttr) - -// ResourceSparseApplyMomentumUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyMomentumUseLocking(value bool) ResourceSparseApplyMomentumAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} +// RestoreSliceAttr is an optional argument to RestoreSlice. +type RestoreSliceAttr func(optionalAttr) -// ResourceSparseApplyMomentumUseNesterov sets the optional use_nesterov attribute to value. +// RestoreSlicePreferredShard sets the optional preferred_shard attribute to value. // -// value: If `True`, the tensor passed to compute grad will be -// var - lr * momentum * accum, so in the end, the var you get is actually -// var - lr * momentum * accum. -// If not specified, defaults to false -func ResourceSparseApplyMomentumUseNesterov(value bool) ResourceSparseApplyMomentumAttr { +// value: Index of file to open first if multiple files match +// `file_pattern`. See the documentation for `Restore`. +// If not specified, defaults to -1 +func RestoreSlicePreferredShard(value int64) RestoreSliceAttr { return func(m optionalAttr) { - m["use_nesterov"] = value + m["preferred_shard"] = value } } -// Update relevant entries in '*var' and '*accum' according to the momentum scheme. -// -// Set use_nesterov = True if you want to use Nesterov momentum. +// Restores a tensor from checkpoint files. // -// That is for rows we have grad for, we update var and accum as follows: +// This is like `Restore` except that restored tensor can be listed as filling +// only a slice of a larger tensor. `shape_and_slice` specifies the shape of the +// larger tensor and the slice that the restored tensor covers. // -// accum = accum * momentum + grad -// var -= lr * accum +// The `shape_and_slice` input has the same format as the +// elements of the `shapes_and_slices` input of the `SaveSlices` op. // // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Learning rate. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// momentum: Momentum. Must be a scalar. +// file_pattern: Must have a single element. The pattern of the files from +// which we read the tensor. +// tensor_name: Must have a single element. The name of the tensor to be +// restored. +// shape_and_slice: Scalar. The shapes and slice specifications to use when +// restoring a tensors. +// dt: The type of the tensor to be restored. // -// Returns the created operation. -func ResourceSparseApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, momentum tf.Output, optional ...ResourceSparseApplyMomentumAttr) (o *tf.Operation) { +// Returns The restored tensor. +func RestoreSlice(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, shape_and_slice tf.Output, dt tf.DataType, optional ...RestoreSliceAttr) (tensor tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dt": dt} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyMomentum", + Type: "RestoreSlice", Input: []tf.Input{ - var_, accum, lr, grad, indices, momentum, + file_pattern, tensor_name, shape_and_slice, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Returns the complex conjugate of a complex number. +// Divides sparse updates into the variable referenced by `resource`. // -// Given a tensor `input` of complex numbers, this operation returns a tensor of -// complex numbers that are the complex conjugate of each element in `input`. The -// complex numbers in `input` must be of the form \\(a + bj\\), where *a* is the -// real part and *b* is the imaginary part. +// This operation computes // -// The complex conjugate returned by this operation is of the form \\(a - bj\\). +// # Scalar indices +// ref[indices, ...] /= updates[...] // -// For example: +// # Vector indices (for each i) +// ref[indices[i], ...] /= updates[i, ...] // -// ``` -// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] -// tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j] -// ``` -func Conj(scope *Scope, input tf.Output) (output tf.Output) { +// # High rank indices (for each i, ..., j) +// ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...] +// +// Duplicate entries are handled correctly: if multiple `indices` reference +// the same location, their contributions multiply. +// +// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// +//
+// +//
+// +// Arguments: +// resource: Should be from a `Variable` node. +// indices: A tensor of indices into the first dimension of `ref`. +// updates: A tensor of updated values to add to `ref`. +// +// Returns the created operation. +func ResourceScatterDiv(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Conj", + Type: "ResourceScatterDiv", + Input: []tf.Input{ + resource, indices, updates, + }, + } + return scope.AddOperation(opspec) +} + +// Mutually reduces multiple tensors of identical type and shape. +func CollectiveReduce(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, merge_op string, final_op string, subdiv_offsets []int64) (data tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "merge_op": merge_op, "final_op": final_op, "subdiv_offsets": subdiv_offsets} + opspec := tf.OpSpec{ + Type: "CollectiveReduce", Input: []tf.Input{ input, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResizeBilinearAttr is an optional argument to ResizeBilinear. -type ResizeBilinearAttr func(optionalAttr) +// StatelessRandomNormalAttr is an optional argument to StatelessRandomNormal. +type StatelessRandomNormalAttr func(optionalAttr) -// ResizeBilinearAlignCorners sets the optional align_corners attribute to value. +// StatelessRandomNormalDtype sets the optional dtype attribute to value. // -// value: If true, the centers of the 4 corner pixels of the input and output tensors are -// aligned, preserving the values at the corner pixels. Defaults to false. -// If not specified, defaults to false -func ResizeBilinearAlignCorners(value bool) ResizeBilinearAttr { +// value: The type of the output. +// If not specified, defaults to DT_FLOAT +func StatelessRandomNormalDtype(value tf.DataType) StatelessRandomNormalAttr { return func(m optionalAttr) { - m["align_corners"] = value + m["dtype"] = value } } -// Resize `images` to `size` using bilinear interpolation. +// Outputs deterministic pseudorandom values from a normal distribution. // -// Input images can be of different types but output images are always float. +// The generated values will have mean 0 and standard deviation 1. +// +// The outputs are a deterministic function of `shape` and `seed`. // // Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. +// shape: The shape of the output tensor. +// seed: 2 seeds (shape [2]). // -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. -func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeBilinearAttr) (resized_images tf.Output) { +// Returns Random values with specified shape. +func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomNormalAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -9003,9 +8904,9 @@ func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ... a(attrs) } opspec := tf.OpSpec{ - Type: "ResizeBilinear", + Type: "StatelessRandomNormal", Input: []tf.Input{ - images, size, + shape, seed, }, Attrs: attrs, } @@ -9013,128 +8914,207 @@ func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ... return op.Output(0) } -// Computes softsign: `features / (abs(features) + 1)`. -func Softsign(scope *Scope, features tf.Output) (activations tf.Output) { +// MaxPoolAttr is an optional argument to MaxPool. +type MaxPoolAttr func(optionalAttr) + +// MaxPoolDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func MaxPoolDataFormat(value string) MaxPoolAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Performs max pooling on the input. +// +// Arguments: +// input: 4-D input to pool over. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. +// +// Returns The max pooled output tensor. +func MaxPool(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Softsign", + Type: "MaxPool", Input: []tf.Input{ - features, + input, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Creates a TensorList which, when stacked, has the value of `tensor`. +// SparseMatMulAttr is an optional argument to SparseMatMul. +type SparseMatMulAttr func(optionalAttr) + +// SparseMatMulTransposeA sets the optional transpose_a attribute to value. +// If not specified, defaults to false +func SparseMatMulTransposeA(value bool) SparseMatMulAttr { + return func(m optionalAttr) { + m["transpose_a"] = value + } +} + +// SparseMatMulTransposeB sets the optional transpose_b attribute to value. +// If not specified, defaults to false +func SparseMatMulTransposeB(value bool) SparseMatMulAttr { + return func(m optionalAttr) { + m["transpose_b"] = value + } +} + +// SparseMatMulAIsSparse sets the optional a_is_sparse attribute to value. +// If not specified, defaults to false +func SparseMatMulAIsSparse(value bool) SparseMatMulAttr { + return func(m optionalAttr) { + m["a_is_sparse"] = value + } +} + +// SparseMatMulBIsSparse sets the optional b_is_sparse attribute to value. +// If not specified, defaults to false +func SparseMatMulBIsSparse(value bool) SparseMatMulAttr { + return func(m optionalAttr) { + m["b_is_sparse"] = value + } +} + +// Multiply matrix "a" by matrix "b". // -// Each tensor in the result list corresponds to one row of the input tensor. +// The inputs must be two-dimensional matrices and the inner dimension of "a" must +// match the outer dimension of "b". This op is optimized for the case where at +// least one of "a" or "b" is sparse. The breakeven for using this versus a dense +// matrix multiply on one platform was 30% zero values in the sparse matrix. // -// tensor: The input tensor. -// output_handle: The list. -func TensorListFromTensor(scope *Scope, tensor tf.Output, element_shape tf.Output) (output_handle tf.Output) { +// The gradient computation of this operation will only take advantage of sparsity +// in the input gradient when that gradient comes from a Relu. +func SparseMatMul(scope *Scope, a tf.Output, b tf.Output, optional ...SparseMatMulAttr) (product tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TensorListFromTensor", + Type: "SparseMatMul", Input: []tf.Input{ - tensor, element_shape, + a, b, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// GenerateVocabRemappingAttr is an optional argument to GenerateVocabRemapping. -type GenerateVocabRemappingAttr func(optionalAttr) - -// GenerateVocabRemappingOldVocabSize sets the optional old_vocab_size attribute to value. +// Concatenates quantized tensors along one dimension. // -// value: Number of entries in the old vocab file to consider. If -1, -// use the entire old vocabulary. -// If not specified, defaults to -1 +// Arguments: +// concat_dim: 0-D. The dimension along which to concatenate. Must be in the +// range [0, rank(values)). +// values: The `N` Tensors to concatenate. Their ranks and types must match, +// and their sizes must match in all dimensions except `concat_dim`. +// input_mins: The minimum scalar values for each of the input tensors. +// input_maxes: The maximum scalar values for each of the input tensors. // -// REQUIRES: value >= -1 -func GenerateVocabRemappingOldVocabSize(value int64) GenerateVocabRemappingAttr { - return func(m optionalAttr) { - m["old_vocab_size"] = value +// Returns A `Tensor` with the concatenation of values stacked along the +// `concat_dim` dimension. This tensor's shape matches that of `values` except +// in `concat_dim` where it has the sum of the sizes.The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents. +func QuantizedConcat(scope *Scope, concat_dim tf.Output, values []tf.Output, input_mins []tf.Output, input_maxes []tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "QuantizedConcat", + Input: []tf.Input{ + concat_dim, tf.OutputList(values), tf.OutputList(input_mins), tf.OutputList(input_maxes), + }, } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } -// Given a path to new and old vocabulary files, returns a remapping Tensor of +// Slice a `SparseTensor` based on the `start` and `size`. // -// length `num_new_vocab`, where `remapping[i]` contains the row number in the old -// vocabulary that corresponds to row `i` in the new vocabulary (starting at line -// `new_vocab_offset` and up to `num_new_vocab` entities), or `-1` if entry `i` -// in the new vocabulary is not in the old vocabulary. The old vocabulary is -// constrained to the first `old_vocab_size` entries if `old_vocab_size` is not the -// default value of -1. +// For example, if the input is // -// `num_vocab_offset` enables -// use in the partitioned variable case, and should generally be set through -// examining partitioning info. The format of the files should be a text file, -// with each line containing a single entity within the vocabulary. +// input_tensor = shape = [2, 7] +// [ a d e ] +// [b c ] // -// For example, with `new_vocab_file` a text file containing each of the following -// elements on a single line: `[f0, f1, f2, f3]`, old_vocab_file = [f1, f0, f3], -// `num_new_vocab = 3, new_vocab_offset = 1`, the returned remapping would be -// `[0, -1, 2]`. +// Graphically the output tensors are: // -// The op also returns a count of how many entries in the new vocabulary -// were present in the old vocabulary, which is used to calculate the number of -// values to initialize in a weight matrix remapping +// sparse_slice([0, 0], [2, 4]) = shape = [2, 4] +// [ a ] +// [b c ] // -// This functionality can be used to remap both row vocabularies (typically, -// features) and column vocabularies (typically, classes) from TensorFlow -// checkpoints. Note that the partitioning logic relies on contiguous vocabularies -// corresponding to div-partitioned variables. Moreover, the underlying remapping -// uses an IndexTable (as opposed to an inexact CuckooTable), so client code should -// use the corresponding index_table_from_file() as the FeatureColumn framework -// does (as opposed to tf.feature_to_id(), which uses a CuckooTable). +// sparse_slice([0, 4], [2, 3]) = shape = [2, 3] +// [ d e ] +// [ ] // // Arguments: -// new_vocab_file: Path to the new vocab file. -// old_vocab_file: Path to the old vocab file. -// new_vocab_offset: How many entries into the new vocab file to start reading. -// num_new_vocab: Number of entries in the new vocab file to remap. +// indices: 2-D tensor represents the indices of the sparse tensor. +// values: 1-D tensor represents the values of the sparse tensor. +// shape: 1-D. tensor represents the shape of the sparse tensor. +// start: 1-D. tensor represents the start of the slice. +// size: 1-D. tensor represents the size of the slice. +// output indices: A list of 1-D tensors represents the indices of the output +// sparse tensors. // -// Returns A Tensor of length num_new_vocab where the element at index i -// is equal to the old ID that maps to the new ID i. This element is -1 for any -// new ID that is not found in the old vocabulary.Number of new vocab entries found in old vocab. -func GenerateVocabRemapping(scope *Scope, new_vocab_file tf.Output, old_vocab_file tf.Output, new_vocab_offset int64, num_new_vocab int64, optional ...GenerateVocabRemappingAttr) (remapping tf.Output, num_present tf.Output) { +// Returns A list of 1-D tensors represents the values of the output sparse +// tensors.A list of 1-D tensors represents the shape of the output sparse +// tensors. +func SparseSlice(scope *Scope, indices tf.Output, values tf.Output, shape tf.Output, start tf.Output, size tf.Output) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"new_vocab_offset": new_vocab_offset, "num_new_vocab": num_new_vocab} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "GenerateVocabRemapping", + Type: "SparseSlice", Input: []tf.Input{ - new_vocab_file, old_vocab_file, + indices, values, shape, start, size, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0), op.Output(1), op.Output(2) } -// Assigns sparse updates to the variable referenced by `resource`. +// Reduces sparse updates into the variable referenced by `resource` using the `min` operation. // // This operation computes // // # Scalar indices -// ref[indices, ...] = updates[...] +// ref[indices, ...] = min(ref[indices, ...], updates[...]) // // # Vector indices (for each i) -// ref[indices[i], ...] = updates[i, ...] +// ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...]) // // # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] = updates[i, ..., j, ...] +// ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) +// +// Duplicate entries are handled correctly: if multiple `indices` reference +// the same location, their contributions are combined. +// +// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// +//
+// +//
// // Arguments: // resource: Should be from a `Variable` node. @@ -9142,12 +9122,12 @@ func GenerateVocabRemapping(scope *Scope, new_vocab_file tf.Output, old_vocab_fi // updates: A tensor of updated values to add to `ref`. // // Returns the created operation. -func ResourceScatterUpdate(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { +func ResourceScatterMin(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ResourceScatterUpdate", + Type: "ResourceScatterMin", Input: []tf.Input{ resource, indices, updates, }, @@ -9155,867 +9135,945 @@ func ResourceScatterUpdate(scope *Scope, resource tf.Output, indices tf.Output, return scope.AddOperation(opspec) } -// Creates and returns an empty tensor list. +// Reshapes a quantized tensor as per the Reshape op. // -// All list elements must be tensors of dtype element_dtype and shape compatible -// with element_shape. +// ``` // -// handle: an empty tensor list. -// element_dtype: the type of elements in the list. -// element_shape: a shape compatible with that of elements in the list. -func EmptyTensorList(scope *Scope, element_shape tf.Output, element_dtype tf.DataType) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"element_dtype": element_dtype} - opspec := tf.OpSpec{ - Type: "EmptyTensorList", +// Arguments: +// +// shape: Defines the shape of the output tensor. +// input_min: The minimum value of the input. +// input_max: The maximum value of the input. +// +// Returns This value is copied from input_min.This value is copied from input_max. +func QuantizedReshape(scope *Scope, tensor tf.Output, shape tf.Output, input_min tf.Output, input_max tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "QuantizedReshape", Input: []tf.Input{ - element_shape, + tensor, shape, input_min, input_max, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// AvgPoolGradAttr is an optional argument to AvgPoolGrad. -type AvgPoolGradAttr func(optionalAttr) - -// AvgPoolGradDataFormat sets the optional data_format attribute to value. +// Returns the truth value of (x != y) element-wise. // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func AvgPoolGradDataFormat(value string) AvgPoolGradAttr { - return func(m optionalAttr) { - m["data_format"] = value +// *NOTE*: `NotEqual` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func NotEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "NotEqual", + Input: []tf.Input{ + x, y, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Computes gradients of the average pooling function. +// Inverse 3D real-valued fast Fourier transform. +// +// Computes the inverse 3-dimensional discrete Fourier transform of a real-valued +// signal over the inner-most 3 dimensions of `input`. +// +// The inner-most 3 dimensions of `input` are assumed to be the result of `RFFT3D`: +// The inner-most dimension contains the `fft_length / 2 + 1` unique components of +// the DFT of a real-valued signal. If `fft_length` is not provided, it is computed +// from the size of the inner-most 3 dimensions of `input`. If the FFT length used +// to compute `input` is odd, it should be provided since it cannot be inferred +// properly. +// +// Along each axis `IRFFT3D` is computed on, if `fft_length` (or +// `fft_length / 2 + 1` for the inner-most dimension) is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. // // Arguments: -// orig_input_shape: 1-D. Shape of the original input to `avg_pool`. -// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. -// the output of `avg_pool`. -// ksize: The size of the sliding window for each dimension of the input. -// strides: The stride of the sliding window for each dimension of the input. -// padding: The type of padding algorithm to use. +// input: A complex64 tensor. +// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. // -// Returns 4-D. Gradients w.r.t. the input of `avg_pool`. -func AvgPoolGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPoolGradAttr) (output tf.Output) { +// Returns A float32 tensor of the same rank as `input`. The inner-most 3 +// dimensions of `input` are replaced with the `fft_length` samples of their +// inverse 3D real Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.irfftn with 3 dimensions. +// @end_compatibility +func IRFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "AvgPoolGrad", + Type: "IRFFT3D", Input: []tf.Input{ - orig_input_shape, grad, + input, fft_length, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// StageClearAttr is an optional argument to StageClear. -type StageClearAttr func(optionalAttr) +// StringSplitAttr is an optional argument to StringSplit. +type StringSplitAttr func(optionalAttr) -// StageClearCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// StringSplitSkipEmpty sets the optional skip_empty attribute to value. // -// REQUIRES: value >= 0 -func StageClearCapacity(value int64) StageClearAttr { +// value: A `bool`. If `True`, skip the empty strings from the result. +// If not specified, defaults to true +func StringSplitSkipEmpty(value bool) StringSplitAttr { return func(m optionalAttr) { - m["capacity"] = value + m["skip_empty"] = value } } -// StageClearMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// Split elements of `input` based on `delimiter` into a `SparseTensor`. // -// REQUIRES: value >= 0 -func StageClearMemoryLimit(value int64) StageClearAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// StageClearContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func StageClearContainer(value string) StageClearAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// StageClearSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func StageClearSharedName(value string) StageClearAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op removes all elements in the underlying container. +// Let N be the size of source (typically N will be the batch size). Split each +// element of `input` based on `delimiter` and return a `SparseTensor` +// containing the splitted tokens. Empty tokens are ignored. // -// Returns the created operation. -func StageClear(scope *Scope, dtypes []tf.DataType, optional ...StageClearAttr) (o *tf.Operation) { +// `delimiter` can be empty, or a string of split characters. If `delimiter` is an +// empty string, each element of `input` is split into individual single-byte +// character strings, including splitting of UTF-8 multibyte sequences. Otherwise +// every character of `delimiter` is a potential split point. +// +// For example: +// N = 2, input[0] is 'hello world' and input[1] is 'a b c', then the output +// will be +// +// indices = [0, 0; +// 0, 1; +// 1, 0; +// 1, 1; +// 1, 2] +// shape = [2, 3] +// values = ['hello', 'world', 'a', 'b', 'c'] +// +// Arguments: +// input: 1-D. Strings to split. +// delimiter: 0-D. Delimiter characters (bytes), or empty string. +// +// Returns A dense matrix of int64 representing the indices of the sparse tensor.A vector of strings corresponding to the splited values.a length-2 vector of int64 representing the shape of the sparse +// tensor, where the first value is N and the second value is the maximum number +// of tokens in a single input entry. +func StringSplit(scope *Scope, input tf.Output, delimiter tf.Output, optional ...StringSplitAttr) (indices tf.Output, values tf.Output, shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "StageClear", - + Type: "StringSplit", + Input: []tf.Input{ + input, delimiter, + }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } -// ComputeAccidentalHitsAttr is an optional argument to ComputeAccidentalHits. -type ComputeAccidentalHitsAttr func(optionalAttr) +// ResourceSparseApplyMomentumAttr is an optional argument to ResourceSparseApplyMomentum. +type ResourceSparseApplyMomentumAttr func(optionalAttr) -// ComputeAccidentalHitsSeed sets the optional seed attribute to value. +// ResourceSparseApplyMomentumUseLocking sets the optional use_locking attribute to value. // -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func ComputeAccidentalHitsSeed(value int64) ComputeAccidentalHitsAttr { +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyMomentumUseLocking(value bool) ResourceSparseApplyMomentumAttr { return func(m optionalAttr) { - m["seed"] = value + m["use_locking"] = value } } -// ComputeAccidentalHitsSeed2 sets the optional seed2 attribute to value. +// ResourceSparseApplyMomentumUseNesterov sets the optional use_nesterov attribute to value. // -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func ComputeAccidentalHitsSeed2(value int64) ComputeAccidentalHitsAttr { +// value: If `True`, the tensor passed to compute grad will be +// var - lr * momentum * accum, so in the end, the var you get is actually +// var - lr * momentum * accum. +// If not specified, defaults to false +func ResourceSparseApplyMomentumUseNesterov(value bool) ResourceSparseApplyMomentumAttr { return func(m optionalAttr) { - m["seed2"] = value + m["use_nesterov"] = value } } -// Computes the ids of the positions in sampled_candidates that match true_labels. +// Update relevant entries in '*var' and '*accum' according to the momentum scheme. // -// When doing log-odds NCE, the result of this op should be passed through a -// SparseToDense op, then added to the logits of the sampled candidates. This has -// the effect of 'removing' the sampled labels that match the true labels by -// making the classifier sure that they are sampled labels. +// Set use_nesterov = True if you want to use Nesterov momentum. +// +// That is for rows we have grad for, we update var and accum as follows: +// +// accum = accum * momentum + grad +// var -= lr * accum // // Arguments: -// true_classes: The true_classes output of UnpackSparseLabels. -// sampled_candidates: The sampled_candidates output of CandidateSampler. -// num_true: Number of true labels per context. +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Learning rate. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// momentum: Momentum. Must be a scalar. // -// Returns A vector of indices corresponding to rows of true_candidates.A vector of IDs of positions in sampled_candidates that match a true_label -// for the row with the corresponding index in indices.A vector of the same length as indices and ids, in which each element -// is -FLOAT_MAX. -func ComputeAccidentalHits(scope *Scope, true_classes tf.Output, sampled_candidates tf.Output, num_true int64, optional ...ComputeAccidentalHitsAttr) (indices tf.Output, ids tf.Output, weights tf.Output) { +// Returns the created operation. +func ResourceSparseApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, momentum tf.Output, optional ...ResourceSparseApplyMomentumAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_true": num_true} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ComputeAccidentalHits", + Type: "ResourceSparseApplyMomentum", Input: []tf.Input{ - true_classes, sampled_candidates, + var_, accum, lr, grad, indices, momentum, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// QuantizedRelu6Attr is an optional argument to QuantizedRelu6. -type QuantizedRelu6Attr func(optionalAttr) - -// QuantizedRelu6OutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_QUINT8 -func QuantizedRelu6OutType(value tf.DataType) QuantizedRelu6Attr { - return func(m optionalAttr) { - m["out_type"] = value - } + return scope.AddOperation(opspec) } -// Computes Quantized Rectified Linear 6: `min(max(features, 0), 6)` +// Returns the complex conjugate of a complex number. // -// Arguments: +// Given a tensor `input` of complex numbers, this operation returns a tensor of +// complex numbers that are the complex conjugate of each element in `input`. The +// complex numbers in `input` must be of the form \\(a + bj\\), where *a* is the +// real part and *b* is the imaginary part. // -// min_features: The float value that the lowest quantized value represents. -// max_features: The float value that the highest quantized value represents. +// The complex conjugate returned by this operation is of the form \\(a - bj\\). // -// Returns Has the same output shape as "features".The float value that the lowest quantized value represents.The float value that the highest quantized value represents. -func QuantizedRelu6(scope *Scope, features tf.Output, min_features tf.Output, max_features tf.Output, optional ...QuantizedRelu6Attr) (activations tf.Output, min_activations tf.Output, max_activations tf.Output) { +// For example: +// +// ``` +// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] +// tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j] +// ``` +func Conj(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "QuantizedRelu6", + Type: "Conj", Input: []tf.Input{ - features, min_features, max_features, + input, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// FixedLengthRecordReaderV2Attr is an optional argument to FixedLengthRecordReaderV2. -type FixedLengthRecordReaderV2Attr func(optionalAttr) - -// FixedLengthRecordReaderV2HeaderBytes sets the optional header_bytes attribute to value. -// -// value: Number of bytes in the header, defaults to 0. -// If not specified, defaults to 0 -func FixedLengthRecordReaderV2HeaderBytes(value int64) FixedLengthRecordReaderV2Attr { - return func(m optionalAttr) { - m["header_bytes"] = value - } -} - -// FixedLengthRecordReaderV2FooterBytes sets the optional footer_bytes attribute to value. -// -// value: Number of bytes in the footer, defaults to 0. -// If not specified, defaults to 0 -func FixedLengthRecordReaderV2FooterBytes(value int64) FixedLengthRecordReaderV2Attr { - return func(m optionalAttr) { - m["footer_bytes"] = value - } -} - -// FixedLengthRecordReaderV2HopBytes sets the optional hop_bytes attribute to value. -// -// value: Number of bytes to hop before each read. Default of 0 means using -// record_bytes. -// If not specified, defaults to 0 -func FixedLengthRecordReaderV2HopBytes(value int64) FixedLengthRecordReaderV2Attr { - return func(m optionalAttr) { - m["hop_bytes"] = value - } + return op.Output(0) } -// FixedLengthRecordReaderV2Container sets the optional container attribute to value. -// -// value: If non-empty, this reader is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func FixedLengthRecordReaderV2Container(value string) FixedLengthRecordReaderV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} +// ResizeBilinearAttr is an optional argument to ResizeBilinear. +type ResizeBilinearAttr func(optionalAttr) -// FixedLengthRecordReaderV2SharedName sets the optional shared_name attribute to value. +// ResizeBilinearAlignCorners sets the optional align_corners attribute to value. // -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. -// If not specified, defaults to "" -func FixedLengthRecordReaderV2SharedName(value string) FixedLengthRecordReaderV2Attr { +// value: If true, the centers of the 4 corner pixels of the input and output tensors are +// aligned, preserving the values at the corner pixels. Defaults to false. +// If not specified, defaults to false +func ResizeBilinearAlignCorners(value bool) ResizeBilinearAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["align_corners"] = value } } -// FixedLengthRecordReaderV2Encoding sets the optional encoding attribute to value. +// Resize `images` to `size` using bilinear interpolation. // -// value: The type of encoding for the file. Currently ZLIB and GZIP -// are supported. Defaults to none. -// If not specified, defaults to "" -func FixedLengthRecordReaderV2Encoding(value string) FixedLengthRecordReaderV2Attr { - return func(m optionalAttr) { - m["encoding"] = value - } -} - -// A Reader that outputs fixed-length records from a file. +// Input images can be of different types but output images are always float. // // Arguments: -// record_bytes: Number of bytes in the record. +// images: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. // -// Returns The handle to reference the Reader. -func FixedLengthRecordReaderV2(scope *Scope, record_bytes int64, optional ...FixedLengthRecordReaderV2Attr) (reader_handle tf.Output) { +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. +func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeBilinearAttr) (resized_images tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"record_bytes": record_bytes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "FixedLengthRecordReaderV2", - + Type: "ResizeBilinear", + Input: []tf.Input{ + images, size, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Converts each string in the input Tensor to its hash mod by a number of buckets. -// -// The hash function is deterministic on the content of the string within the -// process. -// -// Note that the hash function may change from time to time. -// This functionality will be deprecated and it's recommended to use -// `tf.string_to_hash_bucket_fast()` or `tf.string_to_hash_bucket_strong()`. -// -// Arguments: -// -// num_buckets: The number of buckets. -// -// Returns A Tensor of the same shape as the input `string_tensor`. -func StringToHashBucket(scope *Scope, string_tensor tf.Output, num_buckets int64) (output tf.Output) { +// Computes softsign: `features / (abs(features) + 1)`. +func Softsign(scope *Scope, features tf.Output) (activations tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_buckets": num_buckets} opspec := tf.OpSpec{ - Type: "StringToHashBucket", + Type: "Softsign", Input: []tf.Input{ - string_tensor, + features, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes gradients for the exponential linear (Elu) operation. +// Creates a TensorList which, when stacked, has the value of `tensor`. // -// Arguments: -// gradients: The backpropagated gradients to the corresponding Elu operation. -// outputs: The outputs of the corresponding Elu operation. +// Each tensor in the result list corresponds to one row of the input tensor. // -// Returns The gradients: `gradients * (outputs + 1)` if outputs < 0, -// `gradients` otherwise. -func EluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) { +// tensor: The input tensor. +// output_handle: The list. +func TensorListFromTensor(scope *Scope, tensor tf.Output, element_shape tf.Output) (output_handle tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "EluGrad", + Type: "TensorListFromTensor", Input: []tf.Input{ - gradients, outputs, + tensor, element_shape, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Creates a dataset that contains `count` elements from the `input_dataset`. +// GenerateVocabRemappingAttr is an optional argument to GenerateVocabRemapping. +type GenerateVocabRemappingAttr func(optionalAttr) + +// GenerateVocabRemappingOldVocabSize sets the optional old_vocab_size attribute to value. // -// Arguments: +// value: Number of entries in the old vocab file to consider. If -1, +// use the entire old vocabulary. +// If not specified, defaults to -1 // -// count: A scalar representing the number of elements from the `input_dataset` -// that should be taken. A value of `-1` indicates that all of `input_dataset` -// is taken. +// REQUIRES: value >= -1 +func GenerateVocabRemappingOldVocabSize(value int64) GenerateVocabRemappingAttr { + return func(m optionalAttr) { + m["old_vocab_size"] = value + } +} + +// Given a path to new and old vocabulary files, returns a remapping Tensor of +// +// length `num_new_vocab`, where `remapping[i]` contains the row number in the old +// vocabulary that corresponds to row `i` in the new vocabulary (starting at line +// `new_vocab_offset` and up to `num_new_vocab` entities), or `-1` if entry `i` +// in the new vocabulary is not in the old vocabulary. The old vocabulary is +// constrained to the first `old_vocab_size` entries if `old_vocab_size` is not the +// default value of -1. // +// `num_vocab_offset` enables +// use in the partitioned variable case, and should generally be set through +// examining partitioning info. The format of the files should be a text file, +// with each line containing a single entity within the vocabulary. // -func TakeDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// For example, with `new_vocab_file` a text file containing each of the following +// elements on a single line: `[f0, f1, f2, f3]`, old_vocab_file = [f1, f0, f3], +// `num_new_vocab = 3, new_vocab_offset = 1`, the returned remapping would be +// `[0, -1, 2]`. +// +// The op also returns a count of how many entries in the new vocabulary +// were present in the old vocabulary, which is used to calculate the number of +// values to initialize in a weight matrix remapping +// +// This functionality can be used to remap both row vocabularies (typically, +// features) and column vocabularies (typically, classes) from TensorFlow +// checkpoints. Note that the partitioning logic relies on contiguous vocabularies +// corresponding to div-partitioned variables. Moreover, the underlying remapping +// uses an IndexTable (as opposed to an inexact CuckooTable), so client code should +// use the corresponding index_table_from_file() as the FeatureColumn framework +// does (as opposed to tf.feature_to_id(), which uses a CuckooTable). +// +// Arguments: +// new_vocab_file: Path to the new vocab file. +// old_vocab_file: Path to the old vocab file. +// new_vocab_offset: How many entries into the new vocab file to start reading. +// num_new_vocab: Number of entries in the new vocab file to remap. +// +// Returns A Tensor of length num_new_vocab where the element at index i +// is equal to the old ID that maps to the new ID i. This element is -1 for any +// new ID that is not found in the old vocabulary.Number of new vocab entries found in old vocab. +func GenerateVocabRemapping(scope *Scope, new_vocab_file tf.Output, old_vocab_file tf.Output, new_vocab_offset int64, num_new_vocab int64, optional ...GenerateVocabRemappingAttr) (remapping tf.Output, num_present tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{"new_vocab_offset": new_vocab_offset, "num_new_vocab": num_new_vocab} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TakeDataset", + Type: "GenerateVocabRemapping", Input: []tf.Input{ - input_dataset, count, + new_vocab_file, old_vocab_file, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// The gradient operator for the SparseAdd op. +// Assigns sparse updates to the variable referenced by `resource`. // -// The SparseAdd op calculates A + B, where A, B, and the sum are all represented -// as `SparseTensor` objects. This op takes in the upstream gradient w.r.t. -// non-empty values of the sum, and outputs the gradients w.r.t. the non-empty -// values of A and B. +// This operation computes +// +// # Scalar indices +// ref[indices, ...] = updates[...] +// +// # Vector indices (for each i) +// ref[indices[i], ...] = updates[i, ...] +// +// # High rank indices (for each i, ..., j) +// ref[indices[i, ..., j], ...] = updates[i, ..., j, ...] // // Arguments: -// backprop_val_grad: 1-D with shape `[nnz(sum)]`. The gradient with respect to -// the non-empty values of the sum. -// a_indices: 2-D. The `indices` of the `SparseTensor` A, size `[nnz(A), ndims]`. -// b_indices: 2-D. The `indices` of the `SparseTensor` B, size `[nnz(B), ndims]`. -// sum_indices: 2-D. The `indices` of the sum `SparseTensor`, size -// `[nnz(sum), ndims]`. +// resource: Should be from a `Variable` node. +// indices: A tensor of indices into the first dimension of `ref`. +// updates: A tensor of updated values to add to `ref`. // -// Returns 1-D with shape `[nnz(A)]`. The gradient with respect to the -// non-empty values of A.1-D with shape `[nnz(B)]`. The gradient with respect to the -// non-empty values of B. -func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Output, b_indices tf.Output, sum_indices tf.Output) (a_val_grad tf.Output, b_val_grad tf.Output) { +// Returns the created operation. +func ResourceScatterUpdate(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseAddGrad", + Type: "ResourceScatterUpdate", Input: []tf.Input{ - backprop_val_grad, a_indices, b_indices, sum_indices, + resource, indices, updates, }, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return scope.AddOperation(opspec) } -// Computes atan of x element-wise. -func Atan(scope *Scope, x tf.Output) (y tf.Output) { +// Creates and returns an empty tensor list. +// +// All list elements must be tensors of dtype element_dtype and shape compatible +// with element_shape. +// +// handle: an empty tensor list. +// element_dtype: the type of elements in the list. +// element_shape: a shape compatible with that of elements in the list. +func EmptyTensorList(scope *Scope, element_shape tf.Output, element_dtype tf.DataType) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"element_dtype": element_dtype} opspec := tf.OpSpec{ - Type: "Atan", + Type: "EmptyTensorList", Input: []tf.Input{ - x, + element_shape, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Encode audio data using the WAV file format. -// -// This operation will generate a string suitable to be saved out to create a .wav -// audio file. It will be encoded in the 16-bit PCM format. It takes in float -// values in the range -1.0f to 1.0f, and any outside that value will be clamped to -// that range. +// AvgPoolGradAttr is an optional argument to AvgPoolGrad. +type AvgPoolGradAttr func(optionalAttr) + +// AvgPoolGradDataFormat sets the optional data_format attribute to value. // -// `audio` is a 2-D float Tensor of shape `[length, channels]`. -// `sample_rate` is a scalar Tensor holding the rate to use (e.g. 44100). +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func AvgPoolGradDataFormat(value string) AvgPoolGradAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Computes gradients of the average pooling function. // // Arguments: -// audio: 2-D with shape `[length, channels]`. -// sample_rate: Scalar containing the sample frequency. +// orig_input_shape: 1-D. Shape of the original input to `avg_pool`. +// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. +// the output of `avg_pool`. +// ksize: The size of the sliding window for each dimension of the input. +// strides: The stride of the sliding window for each dimension of the input. +// padding: The type of padding algorithm to use. // -// Returns 0-D. WAV-encoded file contents. -func EncodeWav(scope *Scope, audio tf.Output, sample_rate tf.Output) (contents tf.Output) { +// Returns 4-D. Gradients w.r.t. the input of `avg_pool`. +func AvgPoolGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPoolGradAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "EncodeWav", + Type: "AvgPoolGrad", Input: []tf.Input{ - audio, sample_rate, + orig_input_shape, grad, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Converts each string in the input Tensor to its hash mod by a number of buckets. -// -// The hash function is deterministic on the content of the string within the -// process. The hash function is a keyed hash function, where attribute `key` -// defines the key of the hash function. `key` is an array of 2 elements. +// StageClearAttr is an optional argument to StageClear. +type StageClearAttr func(optionalAttr) + +// StageClearCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// A strong hash is important when inputs may be malicious, e.g. URLs with -// additional components. Adversaries could try to make their inputs hash to the -// same bucket for a denial-of-service attack or to skew the results. A strong -// hash prevents this by making it difficult, if not infeasible, to compute inputs -// that hash to the same bucket. This comes at a cost of roughly 4x higher compute -// time than `tf.string_to_hash_bucket_fast`. +// REQUIRES: value >= 0 +func StageClearCapacity(value int64) StageClearAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// StageClearMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// Arguments: -// input: The strings to assign a hash bucket. -// num_buckets: The number of buckets. -// key: The key for the keyed hash function passed as a list of two uint64 -// elements. +// REQUIRES: value >= 0 +func StageClearMemoryLimit(value int64) StageClearAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// StageClearContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func StageClearContainer(value string) StageClearAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// StageClearSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func StageClearSharedName(value string) StageClearAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op removes all elements in the underlying container. // -// Returns A Tensor of the same shape as the input `string_tensor`. -func StringToHashBucketStrong(scope *Scope, input tf.Output, num_buckets int64, key []int64) (output tf.Output) { +// Returns the created operation. +func StageClear(scope *Scope, dtypes []tf.DataType, optional ...StageClearAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_buckets": num_buckets, "key": key} + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "StringToHashBucketStrong", - Input: []tf.Input{ - input, - }, + Type: "StageClear", + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// RegexReplaceAttr is an optional argument to RegexReplace. -type RegexReplaceAttr func(optionalAttr) +// ComputeAccidentalHitsAttr is an optional argument to ComputeAccidentalHits. +type ComputeAccidentalHitsAttr func(optionalAttr) -// RegexReplaceReplaceGlobal sets the optional replace_global attribute to value. +// ComputeAccidentalHitsSeed sets the optional seed attribute to value. // -// value: If True, the replacement is global, otherwise the replacement -// is done only on the first match. -// If not specified, defaults to true -func RegexReplaceReplaceGlobal(value bool) RegexReplaceAttr { +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func ComputeAccidentalHitsSeed(value int64) ComputeAccidentalHitsAttr { return func(m optionalAttr) { - m["replace_global"] = value + m["seed"] = value } } -// Replaces the match of pattern in input with rewrite. +// ComputeAccidentalHitsSeed2 sets the optional seed2 attribute to value. // -// It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func ComputeAccidentalHitsSeed2(value int64) ComputeAccidentalHitsAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Computes the ids of the positions in sampled_candidates that match true_labels. +// +// When doing log-odds NCE, the result of this op should be passed through a +// SparseToDense op, then added to the logits of the sampled candidates. This has +// the effect of 'removing' the sampled labels that match the true labels by +// making the classifier sure that they are sampled labels. // // Arguments: -// input: The text to be processed. -// pattern: The regular expression to match the input. -// rewrite: The rewrite to be applied to the matched expresion. +// true_classes: The true_classes output of UnpackSparseLabels. +// sampled_candidates: The sampled_candidates output of CandidateSampler. +// num_true: Number of true labels per context. // -// Returns The text after applying pattern and rewrite. -func RegexReplace(scope *Scope, input tf.Output, pattern tf.Output, rewrite tf.Output, optional ...RegexReplaceAttr) (output tf.Output) { +// Returns A vector of indices corresponding to rows of true_candidates.A vector of IDs of positions in sampled_candidates that match a true_label +// for the row with the corresponding index in indices.A vector of the same length as indices and ids, in which each element +// is -FLOAT_MAX. +func ComputeAccidentalHits(scope *Scope, true_classes tf.Output, sampled_candidates tf.Output, num_true int64, optional ...ComputeAccidentalHitsAttr) (indices tf.Output, ids tf.Output, weights tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num_true": num_true} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "RegexReplace", + Type: "ComputeAccidentalHits", Input: []tf.Input{ - input, pattern, rewrite, + true_classes, sampled_candidates, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Computes numerical negative value element-wise. -// -// I.e., \\(y = -x\\). -func Neg(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Neg", - Input: []tf.Input{ - x, - }, +// QuantizedRelu6Attr is an optional argument to QuantizedRelu6. +type QuantizedRelu6Attr func(optionalAttr) + +// QuantizedRelu6OutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_QUINT8 +func QuantizedRelu6OutType(value tf.DataType) QuantizedRelu6Attr { + return func(m optionalAttr) { + m["out_type"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Execute a sub graph on a remote processor. -// -// The graph specifications(such as graph itself, input tensors and output names) -// are stored as a serialized protocol buffer of RemoteFusedGraphExecuteInfo -// as serialized_remote_fused_graph_execute_info. -// The specifications will be passed to a dedicated registered -// remote fused graph executor. The executor will send the graph specifications -// to a remote processor and execute that graph. The execution results -// will be passed to consumer nodes as outputs of this node. +// Computes Quantized Rectified Linear 6: `min(max(features, 0), 6)` // // Arguments: -// inputs: Arbitrary number of tensors with arbitrary data types // -// serialized_remote_fused_graph_execute_info: Serialized protocol buffer -// of RemoteFusedGraphExecuteInfo which contains graph specifications. +// min_features: The float value that the lowest quantized value represents. +// max_features: The float value that the highest quantized value represents. // -// Returns Arbitrary number of tensors with arbitrary data types -func RemoteFusedGraphExecute(scope *Scope, inputs []tf.Output, Toutputs []tf.DataType, serialized_remote_fused_graph_execute_info string) (outputs []tf.Output) { +// Returns Has the same output shape as "features".The float value that the lowest quantized value represents.The float value that the highest quantized value represents. +func QuantizedRelu6(scope *Scope, features tf.Output, min_features tf.Output, max_features tf.Output, optional ...QuantizedRelu6Attr) (activations tf.Output, min_activations tf.Output, max_activations tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"Toutputs": Toutputs, "serialized_remote_fused_graph_execute_info": serialized_remote_fused_graph_execute_info} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "RemoteFusedGraphExecute", + Type: "QuantizedRelu6", Input: []tf.Input{ - tf.OutputList(inputs), + features, min_features, max_features, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { - scope.UpdateErr("RemoteFusedGraphExecute", err) - return - } - return outputs + return op.Output(0), op.Output(1), op.Output(2) } -// MaxPool3DGradGradAttr is an optional argument to MaxPool3DGradGrad. -type MaxPool3DGradGradAttr func(optionalAttr) +// FixedLengthRecordReaderV2Attr is an optional argument to FixedLengthRecordReaderV2. +type FixedLengthRecordReaderV2Attr func(optionalAttr) -// MaxPool3DGradGradDataFormat sets the optional data_format attribute to value. +// FixedLengthRecordReaderV2HeaderBytes sets the optional header_bytes attribute to value. // -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func MaxPool3DGradGradDataFormat(value string) MaxPool3DGradGradAttr { +// value: Number of bytes in the header, defaults to 0. +// If not specified, defaults to 0 +func FixedLengthRecordReaderV2HeaderBytes(value int64) FixedLengthRecordReaderV2Attr { return func(m optionalAttr) { - m["data_format"] = value + m["header_bytes"] = value } } -// Computes second-order gradients of the maxpooling function. -// -// Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. +// FixedLengthRecordReaderV2FooterBytes sets the optional footer_bytes attribute to value. // -// Returns Gradients of gradients w.r.t. the input to `max_pool`. -func MaxPool3DGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DGradGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) +// value: Number of bytes in the footer, defaults to 0. +// If not specified, defaults to 0 +func FixedLengthRecordReaderV2FooterBytes(value int64) FixedLengthRecordReaderV2Attr { + return func(m optionalAttr) { + m["footer_bytes"] = value } - opspec := tf.OpSpec{ - Type: "MaxPool3DGradGrad", - Input: []tf.Input{ - orig_input, orig_output, grad, - }, - Attrs: attrs, +} + +// FixedLengthRecordReaderV2HopBytes sets the optional hop_bytes attribute to value. +// +// value: Number of bytes to hop before each read. Default of 0 means using +// record_bytes. +// If not specified, defaults to 0 +func FixedLengthRecordReaderV2HopBytes(value int64) FixedLengthRecordReaderV2Attr { + return func(m optionalAttr) { + m["hop_bytes"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Conv3DBackpropFilterV2Attr is an optional argument to Conv3DBackpropFilterV2. -type Conv3DBackpropFilterV2Attr func(optionalAttr) +// FixedLengthRecordReaderV2Container sets the optional container attribute to value. +// +// value: If non-empty, this reader is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func FixedLengthRecordReaderV2Container(value string) FixedLengthRecordReaderV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} -// Conv3DBackpropFilterV2DataFormat sets the optional data_format attribute to value. +// FixedLengthRecordReaderV2SharedName sets the optional shared_name attribute to value. // -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { +// value: If non-empty, this reader is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. +// If not specified, defaults to "" +func FixedLengthRecordReaderV2SharedName(value string) FixedLengthRecordReaderV2Attr { return func(m optionalAttr) { - m["data_format"] = value + m["shared_name"] = value } } -// Conv3DBackpropFilterV2Dilations sets the optional dilations attribute to value. +// FixedLengthRecordReaderV2Encoding sets the optional encoding attribute to value. // -// value: 1-D tensor of length 5. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each -// filter element on that dimension. The dimension order is determined by the -// value of `data_format`, see above for details. Dilations in the batch and -// depth dimensions must be 1. -// If not specified, defaults to -func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { +// value: The type of encoding for the file. Currently ZLIB and GZIP +// are supported. Defaults to none. +// If not specified, defaults to "" +func FixedLengthRecordReaderV2Encoding(value string) FixedLengthRecordReaderV2Attr { return func(m optionalAttr) { - m["dilations"] = value + m["encoding"] = value } } -// Computes the gradients of 3-D convolution with respect to the filter. +// A Reader that outputs fixed-length records from a file. // // Arguments: -// input: Shape `[batch, depth, rows, cols, in_channels]`. -// filter_sizes: An integer vector representing the tensor shape of `filter`, -// where `filter` is a 5-D -// `[filter_depth, filter_height, filter_width, in_channels, out_channels]` -// tensor. -// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, -// out_channels]`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. -func Conv3DBackpropFilterV2(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropFilterV2Attr) (output tf.Output) { +// record_bytes: Number of bytes in the record. +// +// Returns The handle to reference the Reader. +func FixedLengthRecordReaderV2(scope *Scope, record_bytes int64, optional ...FixedLengthRecordReaderV2Attr) (reader_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} + attrs := map[string]interface{}{"record_bytes": record_bytes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Conv3DBackpropFilterV2", - Input: []tf.Input{ - input, filter_sizes, out_backprop, - }, + Type: "FixedLengthRecordReaderV2", + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// FakeQuantWithMinMaxVarsAttr is an optional argument to FakeQuantWithMinMaxVars. -type FakeQuantWithMinMaxVarsAttr func(optionalAttr) - -// FakeQuantWithMinMaxVarsNumBits sets the optional num_bits attribute to value. -// If not specified, defaults to 8 -func FakeQuantWithMinMaxVarsNumBits(value int64) FakeQuantWithMinMaxVarsAttr { - return func(m optionalAttr) { - m["num_bits"] = value +// The gradient operator for the SparseAdd op. +// +// The SparseAdd op calculates A + B, where A, B, and the sum are all represented +// as `SparseTensor` objects. This op takes in the upstream gradient w.r.t. +// non-empty values of the sum, and outputs the gradients w.r.t. the non-empty +// values of A and B. +// +// Arguments: +// backprop_val_grad: 1-D with shape `[nnz(sum)]`. The gradient with respect to +// the non-empty values of the sum. +// a_indices: 2-D. The `indices` of the `SparseTensor` A, size `[nnz(A), ndims]`. +// b_indices: 2-D. The `indices` of the `SparseTensor` B, size `[nnz(B), ndims]`. +// sum_indices: 2-D. The `indices` of the sum `SparseTensor`, size +// `[nnz(sum), ndims]`. +// +// Returns 1-D with shape `[nnz(A)]`. The gradient with respect to the +// non-empty values of A.1-D with shape `[nnz(B)]`. The gradient with respect to the +// non-empty values of B. +func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Output, b_indices tf.Output, sum_indices tf.Output) (a_val_grad tf.Output, b_val_grad tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseAddGrad", + Input: []tf.Input{ + backprop_val_grad, a_indices, b_indices, sum_indices, + }, } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) } -// FakeQuantWithMinMaxVarsNarrowRange sets the optional narrow_range attribute to value. -// If not specified, defaults to false -func FakeQuantWithMinMaxVarsNarrowRange(value bool) FakeQuantWithMinMaxVarsAttr { - return func(m optionalAttr) { - m["narrow_range"] = value +// Computes atan of x element-wise. +func Atan(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Atan", + Input: []tf.Input{ + x, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Fake-quantize the 'inputs' tensor of type float via global float scalars `min` +// Encode audio data using the WAV file format. // -// and `max` to 'outputs' tensor of same shape as `inputs`. +// This operation will generate a string suitable to be saved out to create a .wav +// audio file. It will be encoded in the 16-bit PCM format. It takes in float +// values in the range -1.0f to 1.0f, and any outside that value will be clamped to +// that range. // -// `[min; max]` define the clamping range for the `inputs` data. -// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` -// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and -// then de-quantized and output as floats in `[min; max]` interval. -// `num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive. +// `audio` is a 2-D float Tensor of shape `[length, channels]`. +// `sample_rate` is a scalar Tensor holding the rate to use (e.g. 44100). // -// This operation has a gradient and thus allows for training `min` and `max` -// values. -func FakeQuantWithMinMaxVars(scope *Scope, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsAttr) (outputs tf.Output) { +// Arguments: +// audio: 2-D with shape `[length, channels]`. +// sample_rate: Scalar containing the sample frequency. +// +// Returns 0-D. WAV-encoded file contents. +func EncodeWav(scope *Scope, audio tf.Output, sample_rate tf.Output) (contents tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "FakeQuantWithMinMaxVars", + Type: "EncodeWav", Input: []tf.Input{ - inputs, min, max, + audio, sample_rate, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Applies softmax to a batched N-D `SparseTensor`. -// -// The inputs represent an N-D SparseTensor with logical shape `[..., B, C]` -// (where `N >= 2`), and with indices sorted in the canonical lexicographic order. -// -// This op is equivalent to applying the normal `tf.nn.softmax()` to each innermost -// logical submatrix with shape `[B, C]`, but with the catch that *the implicitly -// zero elements do not participate*. Specifically, the algorithm is equivalent -// to the following: +// Converts each string in the input Tensor to its hash mod by a number of buckets. // -// (1) Applies `tf.nn.softmax()` to a densified view of each innermost submatrix -// with shape `[B, C]`, along the size-C dimension; -// (2) Masks out the original implicitly-zero locations; -// (3) Renormalizes the remaining elements. +// The hash function is deterministic on the content of the string within the +// process. The hash function is a keyed hash function, where attribute `key` +// defines the key of the hash function. `key` is an array of 2 elements. // -// Hence, the `SparseTensor` result has exactly the same non-zero indices and -// shape. +// A strong hash is important when inputs may be malicious, e.g. URLs with +// additional components. Adversaries could try to make their inputs hash to the +// same bucket for a denial-of-service attack or to skew the results. A strong +// hash prevents this by making it difficult, if not infeasible, to compute inputs +// that hash to the same bucket. This comes at a cost of roughly 4x higher compute +// time than `tf.string_to_hash_bucket_fast`. // // Arguments: -// sp_indices: 2-D. `NNZ x R` matrix with the indices of non-empty values in a -// SparseTensor, in canonical ordering. -// sp_values: 1-D. `NNZ` non-empty values corresponding to `sp_indices`. -// sp_shape: 1-D. Shape of the input SparseTensor. +// input: The strings to assign a hash bucket. +// num_buckets: The number of buckets. +// key: The key for the keyed hash function passed as a list of two uint64 +// elements. // -// Returns 1-D. The `NNZ` values for the result `SparseTensor`. -func SparseSoftmax(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output) (output tf.Output) { +// Returns A Tensor of the same shape as the input `string_tensor`. +func StringToHashBucketStrong(scope *Scope, input tf.Output, num_buckets int64, key []int64) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"num_buckets": num_buckets, "key": key} opspec := tf.OpSpec{ - Type: "SparseSoftmax", + Type: "StringToHashBucketStrong", Input: []tf.Input{ - sp_indices, sp_values, sp_shape, + input, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Partitions `data` into `num_partitions` tensors using indices from `partitions`. -// -// For each index tuple `js` of size `partitions.ndim`, the slice `data[js, ...]` -// becomes part of `outputs[partitions[js]]`. The slices with `partitions[js] = i` -// are placed in `outputs[i]` in lexicographic order of `js`, and the first -// dimension of `outputs[i]` is the number of entries in `partitions` equal to `i`. -// In detail, -// -// ```python -// outputs[i].shape = [sum(partitions == i)] + data.shape[partitions.ndim:] -// -// outputs[i] = pack([data[js, ...] for js if partitions[js] == i]) -// ``` +// RegexReplaceAttr is an optional argument to RegexReplace. +type RegexReplaceAttr func(optionalAttr) + +// RegexReplaceReplaceGlobal sets the optional replace_global attribute to value. // -// `data.shape` must start with `partitions.shape`. +// value: If True, the replacement is global, otherwise the replacement +// is done only on the first match. +// If not specified, defaults to true +func RegexReplaceReplaceGlobal(value bool) RegexReplaceAttr { + return func(m optionalAttr) { + m["replace_global"] = value + } +} + +// Replaces the match of pattern in input with rewrite. // -// For example: +// It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) // -// ```python -// # Scalar partitions. -// partitions = 1 -// num_partitions = 2 -// data = [10, 20] -// outputs[0] = [] # Empty with shape [0, 2] -// outputs[1] = [[10, 20]] +// Arguments: +// input: The text to be processed. +// pattern: The regular expression to match the input. +// rewrite: The rewrite to be applied to the matched expresion. // -// # Vector partitions. -// partitions = [0, 0, 1, 1, 0] -// num_partitions = 2 -// data = [10, 20, 30, 40, 50] -// outputs[0] = [10, 20, 50] -// outputs[1] = [30, 40] -// ``` +// Returns The text after applying pattern and rewrite. +func RegexReplace(scope *Scope, input tf.Output, pattern tf.Output, rewrite tf.Output, optional ...RegexReplaceAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RegexReplace", + Input: []tf.Input{ + input, pattern, rewrite, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes numerical negative value element-wise. // -// See `dynamic_stitch` for an example on how to merge partitions back. +// I.e., \\(y = -x\\). +func Neg(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Neg", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Execute a sub graph on a remote processor. // -//
-// -//
+// The graph specifications(such as graph itself, input tensors and output names) +// are stored as a serialized protocol buffer of RemoteFusedGraphExecuteInfo +// as serialized_remote_fused_graph_execute_info. +// The specifications will be passed to a dedicated registered +// remote fused graph executor. The executor will send the graph specifications +// to a remote processor and execute that graph. The execution results +// will be passed to consumer nodes as outputs of this node. // // Arguments: +// inputs: Arbitrary number of tensors with arbitrary data types // -// partitions: Any shape. Indices in the range `[0, num_partitions)`. -// num_partitions: The number of partitions to output. -func DynamicPartition(scope *Scope, data tf.Output, partitions tf.Output, num_partitions int64) (outputs []tf.Output) { +// serialized_remote_fused_graph_execute_info: Serialized protocol buffer +// of RemoteFusedGraphExecuteInfo which contains graph specifications. +// +// Returns Arbitrary number of tensors with arbitrary data types +func RemoteFusedGraphExecute(scope *Scope, inputs []tf.Output, Toutputs []tf.DataType, serialized_remote_fused_graph_execute_info string) (outputs []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_partitions": num_partitions} + attrs := map[string]interface{}{"Toutputs": Toutputs, "serialized_remote_fused_graph_execute_info": serialized_remote_fused_graph_execute_info} opspec := tf.OpSpec{ - Type: "DynamicPartition", + Type: "RemoteFusedGraphExecute", Input: []tf.Input{ - data, partitions, + tf.OutputList(inputs), }, Attrs: attrs, } @@ -10026,127 +10084,117 @@ func DynamicPartition(scope *Scope, data tf.Output, partitions tf.Output, num_pa var idx int var err error if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { - scope.UpdateErr("DynamicPartition", err) + scope.UpdateErr("RemoteFusedGraphExecute", err) return } return outputs } -// ResourceApplyAdagradAttr is an optional argument to ResourceApplyAdagrad. -type ResourceApplyAdagradAttr func(optionalAttr) +// MaxPool3DGradGradAttr is an optional argument to MaxPool3DGradGrad. +type MaxPool3DGradGradAttr func(optionalAttr) -// ResourceApplyAdagradUseLocking sets the optional use_locking attribute to value. +// MaxPool3DGradGradDataFormat sets the optional data_format attribute to value. // -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyAdagradUseLocking(value bool) ResourceApplyAdagradAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// ResourceApplyAdagradUpdateSlots sets the optional update_slots attribute to value. -// If not specified, defaults to true -func ResourceApplyAdagradUpdateSlots(value bool) ResourceApplyAdagradAttr { +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func MaxPool3DGradGradDataFormat(value string) MaxPool3DGradGradAttr { return func(m optionalAttr) { - m["update_slots"] = value + m["data_format"] = value } } -// Update '*var' according to the adagrad scheme. -// -// accum += grad * grad -// var -= lr * grad * (1 / sqrt(accum)) +// Computes second-order gradients of the maxpooling function. // // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// grad: The gradient. +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. // -// Returns the created operation. -func ResourceApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, optional ...ResourceApplyAdagradAttr) (o *tf.Operation) { +// Returns Gradients of gradients w.r.t. the input to `max_pool`. +func MaxPool3DGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DGradGradAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyAdagrad", + Type: "MaxPool3DGradGrad", Input: []tf.Input{ - var_, accum, lr, grad, + orig_input, orig_output, grad, }, Attrs: attrs, } - return scope.AddOperation(opspec) -} - -// Return the shape of s0 op s1 with broadcast. -// -// Given `s0` and `s1`, tensors that represent shapes, compute `r0`, the -// broadcasted shape. `s0`, `s1` and `r0` are all integer vectors. -func BroadcastArgs(scope *Scope, s0 tf.Output, s1 tf.Output) (r0 tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "BroadcastArgs", - Input: []tf.Input{ - s0, s1, - }, - } op := scope.AddOperation(opspec) return op.Output(0) } -// DataFormatDimMapAttr is an optional argument to DataFormatDimMap. -type DataFormatDimMapAttr func(optionalAttr) +// Conv3DBackpropFilterV2Attr is an optional argument to Conv3DBackpropFilterV2. +type Conv3DBackpropFilterV2Attr func(optionalAttr) -// DataFormatDimMapSrcFormat sets the optional src_format attribute to value. +// Conv3DBackpropFilterV2DataFormat sets the optional data_format attribute to value. // -// value: source data format. -// If not specified, defaults to "NHWC" -func DataFormatDimMapSrcFormat(value string) DataFormatDimMapAttr { +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { - m["src_format"] = value + m["data_format"] = value } } -// DataFormatDimMapDstFormat sets the optional dst_format attribute to value. +// Conv3DBackpropFilterV2Dilations sets the optional dilations attribute to value. // -// value: destination data format. -// If not specified, defaults to "NCHW" -func DataFormatDimMapDstFormat(value string) DataFormatDimMapAttr { +// value: 1-D tensor of length 5. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each +// filter element on that dimension. The dimension order is determined by the +// value of `data_format`, see above for details. Dilations in the batch and +// depth dimensions must be 1. +// If not specified, defaults to +func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { - m["dst_format"] = value + m["dilations"] = value } } -// Returns the dimension index in the destination data format given the one in -// -// the source data format. +// Computes the gradients of 3-D convolution with respect to the filter. // // Arguments: -// x: A Tensor with each element as a dimension index in source data format. -// Must be in the range [-4, 4). -// -// Returns A Tensor with each element as a dimension index in destination data format. -func DataFormatDimMap(scope *Scope, x tf.Output, optional ...DataFormatDimMapAttr) (y tf.Output) { +// input: Shape `[batch, depth, rows, cols, in_channels]`. +// filter_sizes: An integer vector representing the tensor shape of `filter`, +// where `filter` is a 5-D +// `[filter_depth, filter_height, filter_width, in_channels, out_channels]` +// tensor. +// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, +// out_channels]`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +func Conv3DBackpropFilterV2(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropFilterV2Attr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "DataFormatDimMap", + Type: "Conv3DBackpropFilterV2", Input: []tf.Input{ - x, + input, filter_sizes, out_backprop, }, Attrs: attrs, } @@ -10154,38 +10202,38 @@ func DataFormatDimMap(scope *Scope, x tf.Output, optional ...DataFormatDimMapAtt return op.Output(0) } -// ResourceApplyPowerSignAttr is an optional argument to ResourceApplyPowerSign. -type ResourceApplyPowerSignAttr func(optionalAttr) +// FakeQuantWithMinMaxVarsAttr is an optional argument to FakeQuantWithMinMaxVars. +type FakeQuantWithMinMaxVarsAttr func(optionalAttr) -// ResourceApplyPowerSignUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var and m tensors is -// protected by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyPowerSignUseLocking(value bool) ResourceApplyPowerSignAttr { +// FakeQuantWithMinMaxVarsNumBits sets the optional num_bits attribute to value. +// If not specified, defaults to 8 +func FakeQuantWithMinMaxVarsNumBits(value int64) FakeQuantWithMinMaxVarsAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["num_bits"] = value } } -// Update '*var' according to the AddSign update. +// FakeQuantWithMinMaxVarsNarrowRange sets the optional narrow_range attribute to value. +// If not specified, defaults to false +func FakeQuantWithMinMaxVarsNarrowRange(value bool) FakeQuantWithMinMaxVarsAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + +// Fake-quantize the 'inputs' tensor of type float via global float scalars `min` // -// m_t <- beta1 * m_{t-1} + (1 - beta1) * g -// update <- exp(logbase * sign_decay * sign(g) * sign(m_t)) * g -// variable <- variable - lr_t * update +// and `max` to 'outputs' tensor of same shape as `inputs`. // -// Arguments: -// var_: Should be from a Variable(). -// m: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// logbase: Must be a scalar. -// sign_decay: Must be a scalar. -// beta: Must be a scalar. -// grad: The gradient. +// `[min; max]` define the clamping range for the `inputs` data. +// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` +// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and +// then de-quantized and output as floats in `[min; max]` interval. +// `num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive. // -// Returns the created operation. -func ResourceApplyPowerSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Output, logbase tf.Output, sign_decay tf.Output, beta tf.Output, grad tf.Output, optional ...ResourceApplyPowerSignAttr) (o *tf.Operation) { +// This operation has a gradient and thus allows for training `min` and `max` +// values. +func FakeQuantWithMinMaxVars(scope *Scope, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsAttr) (outputs tf.Output) { if scope.Err() != nil { return } @@ -10194,161 +10242,160 @@ func ResourceApplyPowerSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Out a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyPowerSign", + Type: "FakeQuantWithMinMaxVars", Input: []tf.Input{ - var_, m, lr, logbase, sign_decay, beta, grad, + inputs, min, max, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Locks a mutex resource. The output is the lock. So long as the lock tensor -// -// is alive, any other request to use `MutexLock` with this mutex will wait. -// -// This is particularly useful for creating a critical section when used in -// conjunction with `MutexLockIdentity`: -// -// ```python -// -// mutex = mutex_v2( -// shared_name=handle_name, container=container, name=name) -// -// def execute_in_critical_section(fn, *args, **kwargs): -// lock = gen_resource_variable_ops.mutex_lock(mutex) -// -// with ops.control_dependencies([lock]): -// r = fn(*args, **kwargs) -// -// with ops.control_dependencies(nest.flatten(r)): -// with ops.colocate_with(mutex): -// ensure_lock_exists = mutex_lock_identity(lock) -// -// # Make sure that if any element of r is accessed, all of -// # them are executed together. -// r = nest.map_structure(tf.identity, r) +// Applies softmax to a batched N-D `SparseTensor`. // -// with ops.control_dependencies([ensure_lock_exists]): -// return nest.map_structure(tf.identity, r) -// ``` +// The inputs represent an N-D SparseTensor with logical shape `[..., B, C]` +// (where `N >= 2`), and with indices sorted in the canonical lexicographic order. // -// While `fn` is running in the critical section, no other functions which wish to -// use this critical section may run. +// This op is equivalent to applying the normal `tf.nn.softmax()` to each innermost +// logical submatrix with shape `[B, C]`, but with the catch that *the implicitly +// zero elements do not participate*. Specifically, the algorithm is equivalent +// to the following: // -// Often the use case is that two executions of the same graph, in parallel, -// wish to run `fn`; and we wish to ensure that only one of them executes -// at a time. This is especially important if `fn` modifies one or more -// variables at a time. +// (1) Applies `tf.nn.softmax()` to a densified view of each innermost submatrix +// with shape `[B, C]`, along the size-C dimension; +// (2) Masks out the original implicitly-zero locations; +// (3) Renormalizes the remaining elements. // -// It is also useful if two separate functions must share a resource, but we -// wish to ensure the usage is exclusive. +// Hence, the `SparseTensor` result has exactly the same non-zero indices and +// shape. // // Arguments: -// mutex: The mutex resource to lock. +// sp_indices: 2-D. `NNZ x R` matrix with the indices of non-empty values in a +// SparseTensor, in canonical ordering. +// sp_values: 1-D. `NNZ` non-empty values corresponding to `sp_indices`. +// sp_shape: 1-D. Shape of the input SparseTensor. // -// Returns A tensor that keeps a shared pointer to a lock on the mutex; -// when the Tensor is destroyed, the use count on the shared pointer is decreased -// by 1. When it reaches 0, the lock is released. -func MutexLock(scope *Scope, mutex tf.Output) (mutex_lock tf.Output) { +// Returns 1-D. The `NNZ` values for the result `SparseTensor`. +func SparseSoftmax(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "MutexLock", + Type: "SparseSoftmax", Input: []tf.Input{ - mutex, + sp_indices, sp_values, sp_shape, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the mean along segments of a tensor. +// Partitions `data` into `num_partitions` tensors using indices from `partitions`. // -// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of -// segments. +// For each index tuple `js` of size `partitions.ndim`, the slice `data[js, ...]` +// becomes part of `outputs[partitions[js]]`. The slices with `partitions[js] = i` +// are placed in `outputs[i]` in lexicographic order of `js`, and the first +// dimension of `outputs[i]` is the number of entries in `partitions` equal to `i`. +// In detail, // -// Computes a tensor such that -// \\(output_i = \frac{\sum_j data_j}{N}\\) where `mean` is -// over `j` such that `segment_ids[j] == i` and `N` is the total number of -// values summed. +// ```python +// outputs[i].shape = [sum(partitions == i)] + data.shape[partitions.ndim:] // -// If the mean is empty for a given segment ID `i`, `output[i] = 0`. +// outputs[i] = pack([data[js, ...] for js if partitions[js] == i]) +// ``` +// +// `data.shape` must start with `partitions.shape`. +// +// For example: +// +// ```python +// # Scalar partitions. +// partitions = 1 +// num_partitions = 2 +// data = [10, 20] +// outputs[0] = [] # Empty with shape [0, 2] +// outputs[1] = [[10, 20]] +// +// # Vector partitions. +// partitions = [0, 0, 1, 1, 0] +// num_partitions = 2 +// data = [10, 20, 30, 40, 50] +// outputs[0] = [10, 20, 50] +// outputs[1] = [30, 40] +// ``` +// +// See `dynamic_stitch` for an example on how to merge partitions back. // //
-// +// //
// // Arguments: // -// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s -// first dimension. Values should be sorted and can be repeated. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SegmentMean(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { +// partitions: Any shape. Indices in the range `[0, num_partitions)`. +// num_partitions: The number of partitions to output. +func DynamicPartition(scope *Scope, data tf.Output, partitions tf.Output, num_partitions int64) (outputs []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"num_partitions": num_partitions} opspec := tf.OpSpec{ - Type: "SegmentMean", + Type: "DynamicPartition", Input: []tf.Input{ - data, segment_ids, + data, partitions, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { + scope.UpdateErr("DynamicPartition", err) + return + } + return outputs } -// ResourceSparseApplyCenteredRMSPropAttr is an optional argument to ResourceSparseApplyCenteredRMSProp. -type ResourceSparseApplyCenteredRMSPropAttr func(optionalAttr) +// ResourceApplyAdagradAttr is an optional argument to ResourceApplyAdagrad. +type ResourceApplyAdagradAttr func(optionalAttr) -// ResourceSparseApplyCenteredRMSPropUseLocking sets the optional use_locking attribute to value. +// ResourceApplyAdagradUseLocking sets the optional use_locking attribute to value. // -// value: If `True`, updating of the var, mg, ms, and mom tensors is -// protected by a lock; otherwise the behavior is undefined, but may exhibit less +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less // contention. // If not specified, defaults to false -func ResourceSparseApplyCenteredRMSPropUseLocking(value bool) ResourceSparseApplyCenteredRMSPropAttr { +func ResourceApplyAdagradUseLocking(value bool) ResourceApplyAdagradAttr { return func(m optionalAttr) { m["use_locking"] = value } } -// Update '*var' according to the centered RMSProp algorithm. -// -// The centered RMSProp algorithm uses an estimate of the centered second moment -// (i.e., the variance) for normalization, as opposed to regular RMSProp, which -// uses the (uncentered) second moment. This often helps with training, but is -// slightly more expensive in terms of computation and memory. -// -// Note that in dense implementation of this algorithm, mg, ms, and mom will -// update even if the grad is zero, but in this sparse implementation, mg, ms, -// and mom will not update in iterations during which the grad is zero. -// -// mean_square = decay * mean_square + (1-decay) * gradient ** 2 -// mean_grad = decay * mean_grad + (1-decay) * gradient -// Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2) +// ResourceApplyAdagradUpdateSlots sets the optional update_slots attribute to value. +// If not specified, defaults to true +func ResourceApplyAdagradUpdateSlots(value bool) ResourceApplyAdagradAttr { + return func(m optionalAttr) { + m["update_slots"] = value + } +} + +// Update '*var' according to the adagrad scheme. // -// ms <- rho * ms_{t-1} + (1-rho) * grad * grad -// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) -// var <- var - mom +// accum += grad * grad +// var -= lr * grad * (1 / sqrt(accum)) // // Arguments: // var_: Should be from a Variable(). -// mg: Should be from a Variable(). -// ms: Should be from a Variable(). -// mom: Should be from a Variable(). +// accum: Should be from a Variable(). // lr: Scaling factor. Must be a scalar. -// rho: Decay rate. Must be a scalar. -// -// epsilon: Ridge term. Must be a scalar. // grad: The gradient. -// indices: A vector of indices into the first dimension of var, ms and mom. // // Returns the created operation. -func ResourceSparseApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyCenteredRMSPropAttr) (o *tf.Operation) { +func ResourceApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, optional ...ResourceApplyAdagradAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -10357,174 +10404,116 @@ func ResourceSparseApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Outp a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyCenteredRMSProp", + Type: "ResourceApplyAdagrad", Input: []tf.Input{ - var_, mg, ms, mom, lr, rho, momentum, epsilon, grad, indices, + var_, accum, lr, grad, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// Creates a dataset that batches `batch_size` elements from `input_dataset`. -// -// Arguments: -// -// batch_size: A scalar representing the number of elements to accumulate in a -// batch. -// +// Return the shape of s0 op s1 with broadcast. // -func BatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Given `s0` and `s1`, tensors that represent shapes, compute `r0`, the +// broadcasted shape. `s0`, `s1` and `r0` are all integer vectors. +func BroadcastArgs(scope *Scope, s0 tf.Output, s1 tf.Output) (r0 tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "BatchDataset", + Type: "BroadcastArgs", Input: []tf.Input{ - input_dataset, batch_size, + s0, s1, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Says whether the targets are in the top `K` predictions. -// -// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the -// prediction for the target class is among the top `k` predictions among -// all predictions for example `i`. Note that the behavior of `InTopK` differs -// from the `TopK` op in its handling of ties; if multiple classes have the -// same prediction value and straddle the top-`k` boundary, all of those -// classes are considered to be in the top `k`. +// DataFormatDimMapAttr is an optional argument to DataFormatDimMap. +type DataFormatDimMapAttr func(optionalAttr) + +// DataFormatDimMapSrcFormat sets the optional src_format attribute to value. // -// More formally, let +// value: source data format. +// If not specified, defaults to "NHWC" +func DataFormatDimMapSrcFormat(value string) DataFormatDimMapAttr { + return func(m optionalAttr) { + m["src_format"] = value + } +} + +// DataFormatDimMapDstFormat sets the optional dst_format attribute to value. // -// \\(predictions_i\\) be the predictions for all classes for example `i`, -// \\(targets_i\\) be the target class for example `i`, -// \\(out_i\\) be the output for example `i`, +// value: destination data format. +// If not specified, defaults to "NCHW" +func DataFormatDimMapDstFormat(value string) DataFormatDimMapAttr { + return func(m optionalAttr) { + m["dst_format"] = value + } +} + +// Returns the dimension index in the destination data format given the one in // -// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ +// the source data format. // // Arguments: -// predictions: A `batch_size` x `classes` tensor. -// targets: A `batch_size` vector of class ids. -// k: Number of top elements to look at for computing precision. +// x: A Tensor with each element as a dimension index in source data format. +// Must be in the range [-4, 4). // -// Returns Computed precision at `k` as a `bool Tensor`. -func InTopKV2(scope *Scope, predictions tf.Output, targets tf.Output, k tf.Output) (precision tf.Output) { +// Returns A Tensor with each element as a dimension index in destination data format. +func DataFormatDimMap(scope *Scope, x tf.Output, optional ...DataFormatDimMapAttr) (y tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "InTopKV2", + Type: "DataFormatDimMap", Input: []tf.Input{ - predictions, targets, k, + x, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// DecodeAndCropJpegAttr is an optional argument to DecodeAndCropJpeg. -type DecodeAndCropJpegAttr func(optionalAttr) - -// DecodeAndCropJpegChannels sets the optional channels attribute to value. -// -// value: Number of color channels for the decoded image. -// If not specified, defaults to 0 -func DecodeAndCropJpegChannels(value int64) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["channels"] = value - } -} - -// DecodeAndCropJpegRatio sets the optional ratio attribute to value. -// -// value: Downscaling ratio. -// If not specified, defaults to 1 -func DecodeAndCropJpegRatio(value int64) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["ratio"] = value - } -} - -// DecodeAndCropJpegFancyUpscaling sets the optional fancy_upscaling attribute to value. -// -// value: If true use a slower but nicer upscaling of the -// chroma planes (yuv420/422 only). -// If not specified, defaults to true -func DecodeAndCropJpegFancyUpscaling(value bool) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["fancy_upscaling"] = value - } -} +// ResourceApplyPowerSignAttr is an optional argument to ResourceApplyPowerSign. +type ResourceApplyPowerSignAttr func(optionalAttr) -// DecodeAndCropJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value. +// ResourceApplyPowerSignUseLocking sets the optional use_locking attribute to value. // -// value: If true try to recover an image from truncated input. +// value: If `True`, updating of the var and m tensors is +// protected by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. // If not specified, defaults to false -func DecodeAndCropJpegTryRecoverTruncated(value bool) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["try_recover_truncated"] = value - } -} - -// DecodeAndCropJpegAcceptableFraction sets the optional acceptable_fraction attribute to value. -// -// value: The minimum required fraction of lines before a truncated -// input is accepted. -// If not specified, defaults to 1 -func DecodeAndCropJpegAcceptableFraction(value float32) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["acceptable_fraction"] = value - } -} - -// DecodeAndCropJpegDctMethod sets the optional dct_method attribute to value. -// -// value: string specifying a hint about the algorithm used for -// decompression. Defaults to "" which maps to a system-specific -// default. Currently valid values are ["INTEGER_FAST", -// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal -// jpeg library changes to a version that does not have that specific -// option.) -// If not specified, defaults to "" -func DecodeAndCropJpegDctMethod(value string) DecodeAndCropJpegAttr { +func ResourceApplyPowerSignUseLocking(value bool) ResourceApplyPowerSignAttr { return func(m optionalAttr) { - m["dct_method"] = value + m["use_locking"] = value } } -// Decode and Crop a JPEG-encoded image to a uint8 tensor. -// -// The attr `channels` indicates the desired number of color channels for the -// decoded image. -// -// Accepted values are: -// -// * 0: Use the number of channels in the JPEG-encoded image. -// * 1: output a grayscale image. -// * 3: output an RGB image. -// -// If needed, the JPEG-encoded image is transformed to match the requested number -// of color channels. -// -// The attr `ratio` allows downscaling the image by an integer factor during -// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than -// downscaling the image later. -// +// Update '*var' according to the AddSign update. // -// It is equivalent to a combination of decode and crop, but much faster by only -// decoding partial jpeg image. +// m_t <- beta1 * m_{t-1} + (1 - beta1) * g +// update <- exp(logbase * sign_decay * sign(g) * sign(m_t)) * g +// variable <- variable - lr_t * update // // Arguments: -// contents: 0-D. The JPEG-encoded image. -// crop_window: 1-D. The crop window: [crop_y, crop_x, crop_height, crop_width]. +// var_: Should be from a Variable(). +// m: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// logbase: Must be a scalar. +// sign_decay: Must be a scalar. +// beta: Must be a scalar. +// grad: The gradient. // -// Returns 3-D with shape `[height, width, channels]`.. -func DecodeAndCropJpeg(scope *Scope, contents tf.Output, crop_window tf.Output, optional ...DecodeAndCropJpegAttr) (image tf.Output) { +// Returns the created operation. +func ResourceApplyPowerSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Output, logbase tf.Output, sign_decay tf.Output, beta tf.Output, grad tf.Output, optional ...ResourceApplyPowerSignAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -10533,341 +10522,337 @@ func DecodeAndCropJpeg(scope *Scope, contents tf.Output, crop_window tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "DecodeAndCropJpeg", + Type: "ResourceApplyPowerSign", Input: []tf.Input{ - contents, crop_window, + var_, m, lr, logbase, sign_decay, beta, grad, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// AllCandidateSamplerAttr is an optional argument to AllCandidateSampler. -type AllCandidateSamplerAttr func(optionalAttr) - -// AllCandidateSamplerSeed sets the optional seed attribute to value. +// Locks a mutex resource. The output is the lock. So long as the lock tensor // -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func AllCandidateSamplerSeed(value int64) AllCandidateSamplerAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// AllCandidateSamplerSeed2 sets the optional seed2 attribute to value. +// is alive, any other request to use `MutexLock` with this mutex will wait. // -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func AllCandidateSamplerSeed2(value int64) AllCandidateSamplerAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Generates labels for candidate sampling with a learned unigram distribution. +// This is particularly useful for creating a critical section when used in +// conjunction with `MutexLockIdentity`: // -// See explanations of candidate sampling and the data formats at -// go/candidate-sampling. +// ```python // -// For each batch, this op picks a single set of sampled candidate labels. +// mutex = mutex_v2( +// shared_name=handle_name, container=container, name=name) // -// The advantages of sampling candidates per-batch are simplicity and the -// possibility of efficient dense matrix multiplication. The disadvantage is that -// the sampled candidates must be chosen independently of the context and of the -// true labels. +// def execute_in_critical_section(fn, *args, **kwargs): +// lock = gen_resource_variable_ops.mutex_lock(mutex) +// +// with ops.control_dependencies([lock]): +// r = fn(*args, **kwargs) +// +// with ops.control_dependencies(nest.flatten(r)): +// with ops.colocate_with(mutex): +// ensure_lock_exists = mutex_lock_identity(lock) +// +// # Make sure that if any element of r is accessed, all of +// # them are executed together. +// r = nest.map_structure(tf.identity, r) +// +// with ops.control_dependencies([ensure_lock_exists]): +// return nest.map_structure(tf.identity, r) +// ``` +// +// While `fn` is running in the critical section, no other functions which wish to +// use this critical section may run. +// +// Often the use case is that two executions of the same graph, in parallel, +// wish to run `fn`; and we wish to ensure that only one of them executes +// at a time. This is especially important if `fn` modifies one or more +// variables at a time. +// +// It is also useful if two separate functions must share a resource, but we +// wish to ensure the usage is exclusive. // // Arguments: -// true_classes: A batch_size * num_true matrix, in which each row contains the -// IDs of the num_true target_classes in the corresponding original label. -// num_true: Number of true labels per context. -// num_sampled: Number of candidates to produce. -// unique: If unique is true, we sample with rejection, so that all sampled -// candidates in a batch are unique. This requires some approximation to -// estimate the post-rejection sampling probabilities. +// mutex: The mutex resource to lock. // -// Returns A vector of length num_sampled, in which each element is -// the ID of a sampled candidate.A batch_size * num_true matrix, representing -// the number of times each candidate is expected to occur in a batch -// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled -// candidate representing the number of times the candidate is expected -// to occur in a batch of sampled candidates. If unique=true, then this is a -// probability. -func AllCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, optional ...AllCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { +// Returns A tensor that keeps a shared pointer to a lock on the mutex; +// when the Tensor is destroyed, the use count on the shared pointer is decreased +// by 1. When it reaches 0, the lock is released. +func MutexLock(scope *Scope, mutex tf.Output) (mutex_lock tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "AllCandidateSampler", + Type: "MutexLock", Input: []tf.Input{ - true_classes, + mutex, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Adds two `SparseTensor` objects to produce another `SparseTensor`. +// Computes the mean along segments of a tensor. // -// The input `SparseTensor` objects' indices are assumed ordered in standard -// lexicographic order. If this is not the case, before this step run -// `SparseReorder` to restore index ordering. +// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of +// segments. // -// By default, if two values sum to zero at some index, the output `SparseTensor` -// would still include that particular location in its index, storing a zero in the -// corresponding value slot. To override this, callers can specify `thresh`, -// indicating that if the sum has a magnitude strictly smaller than `thresh`, its -// corresponding value and index would then not be included. In particular, -// `thresh == 0` (default) means everything is kept and actual thresholding happens -// only for a positive value. +// Computes a tensor such that +// \\(output_i = \frac{\sum_j data_j}{N}\\) where `mean` is +// over `j` such that `segment_ids[j] == i` and `N` is the total number of +// values summed. // -// In the following shapes, `nnz` is the count after taking `thresh` into account. +// If the mean is empty for a given segment ID `i`, `output[i] = 0`. +// +//
+// +//
// // Arguments: -// a_indices: 2-D. The `indices` of the first `SparseTensor`, size `[nnz, ndims]` Matrix. -// a_values: 1-D. The `values` of the first `SparseTensor`, size `[nnz]` Vector. -// a_shape: 1-D. The `shape` of the first `SparseTensor`, size `[ndims]` Vector. -// b_indices: 2-D. The `indices` of the second `SparseTensor`, size `[nnz, ndims]` Matrix. -// b_values: 1-D. The `values` of the second `SparseTensor`, size `[nnz]` Vector. -// b_shape: 1-D. The `shape` of the second `SparseTensor`, size `[ndims]` Vector. -// thresh: 0-D. The magnitude threshold that determines if an output value/index -// pair takes space. -func SparseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output, thresh tf.Output) (sum_indices tf.Output, sum_values tf.Output, sum_shape tf.Output) { +// +// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s +// first dimension. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SegmentMean(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseAdd", + Type: "SegmentMean", Input: []tf.Input{ - a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh, + data, segment_ids, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// OrderedMapPeekAttr is an optional argument to OrderedMapPeek. -type OrderedMapPeekAttr func(optionalAttr) +// ResourceSparseApplyCenteredRMSPropAttr is an optional argument to ResourceSparseApplyCenteredRMSProp. +type ResourceSparseApplyCenteredRMSPropAttr func(optionalAttr) -// OrderedMapPeekCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// ResourceSparseApplyCenteredRMSPropUseLocking sets the optional use_locking attribute to value. // -// REQUIRES: value >= 0 -func OrderedMapPeekCapacity(value int64) OrderedMapPeekAttr { +// value: If `True`, updating of the var, mg, ms, and mom tensors is +// protected by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyCenteredRMSPropUseLocking(value bool) ResourceSparseApplyCenteredRMSPropAttr { return func(m optionalAttr) { - m["capacity"] = value + m["use_locking"] = value } } -// OrderedMapPeekMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// Update '*var' according to the centered RMSProp algorithm. // -// REQUIRES: value >= 0 -func OrderedMapPeekMemoryLimit(value int64) OrderedMapPeekAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// OrderedMapPeekContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func OrderedMapPeekContainer(value string) OrderedMapPeekAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// OrderedMapPeekSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func OrderedMapPeekSharedName(value string) OrderedMapPeekAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op peeks at the values at the specified key. If the +// The centered RMSProp algorithm uses an estimate of the centered second moment +// (i.e., the variance) for normalization, as opposed to regular RMSProp, which +// uses the (uncentered) second moment. This often helps with training, but is +// slightly more expensive in terms of computation and memory. // -// underlying container does not contain this key -// this op will block until it does. This Op is optimized for -// performance. -func OrderedMapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...OrderedMapPeekAttr) (values []tf.Output) { +// Note that in dense implementation of this algorithm, mg, ms, and mom will +// update even if the grad is zero, but in this sparse implementation, mg, ms, +// and mom will not update in iterations during which the grad is zero. +// +// mean_square = decay * mean_square + (1-decay) * gradient ** 2 +// mean_grad = decay * mean_grad + (1-decay) * gradient +// Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2) +// +// ms <- rho * ms_{t-1} + (1-rho) * grad * grad +// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) +// var <- var - mom +// +// Arguments: +// var_: Should be from a Variable(). +// mg: Should be from a Variable(). +// ms: Should be from a Variable(). +// mom: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// rho: Decay rate. Must be a scalar. +// +// epsilon: Ridge term. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var, ms and mom. +// +// Returns the created operation. +func ResourceSparseApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyCenteredRMSPropAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "OrderedMapPeek", + Type: "ResourceSparseApplyCenteredRMSProp", Input: []tf.Input{ - key, indices, + var_, mg, ms, mom, lr, rho, momentum, epsilon, grad, indices, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("OrderedMapPeek", err) - return - } - return values + return scope.AddOperation(opspec) } -// Inverse fast Fourier transform. -// -// Computes the inverse 1-dimensional discrete Fourier transform over the -// inner-most dimension of `input`. +// Creates a dataset that batches `batch_size` elements from `input_dataset`. // // Arguments: -// input: A complex64 tensor. // -// Returns A complex64 tensor of the same shape as `input`. The inner-most -// dimension of `input` is replaced with its inverse 1D Fourier transform. +// batch_size: A scalar representing the number of elements to accumulate in a +// batch. // -// @compatibility(numpy) -// Equivalent to np.fft.ifft -// @end_compatibility -func IFFT(scope *Scope, input tf.Output) (output tf.Output) { +// +func BatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "IFFT", + Type: "BatchDataset", Input: []tf.Input{ - input, + input_dataset, batch_size, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Generates values in an interval. +// Says whether the targets are in the top `K` predictions. // -// A sequence of `num` evenly-spaced values are generated beginning at `start`. -// If `num > 1`, the values in the sequence increase by `stop - start / num - 1`, -// so that the last one is exactly `stop`. +// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the +// prediction for the target class is among the top `k` predictions among +// all predictions for example `i`. Note that the behavior of `InTopK` differs +// from the `TopK` op in its handling of ties; if multiple classes have the +// same prediction value and straddle the top-`k` boundary, all of those +// classes are considered to be in the top `k`. // -// For example: +// More formally, let // -// ``` -// tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0] -// ``` +// \\(predictions_i\\) be the predictions for all classes for example `i`, +// \\(targets_i\\) be the target class for example `i`, +// \\(out_i\\) be the output for example `i`, +// +// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ // // Arguments: -// start: First entry in the range. -// stop: Last entry in the range. -// num: Number of values to generate. +// predictions: A `batch_size` x `classes` tensor. +// targets: A `batch_size` vector of class ids. +// k: Number of top elements to look at for computing precision. // -// Returns 1-D. The generated values. -func LinSpace(scope *Scope, start tf.Output, stop tf.Output, num tf.Output) (output tf.Output) { +// Returns Computed precision at `k` as a `bool Tensor`. +func InTopKV2(scope *Scope, predictions tf.Output, targets tf.Output, k tf.Output) (precision tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "LinSpace", + Type: "InTopKV2", Input: []tf.Input{ - start, stop, num, + predictions, targets, k, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// DestroyResourceOpAttr is an optional argument to DestroyResourceOp. -type DestroyResourceOpAttr func(optionalAttr) +// DecodeAndCropJpegAttr is an optional argument to DecodeAndCropJpeg. +type DecodeAndCropJpegAttr func(optionalAttr) -// DestroyResourceOpIgnoreLookupError sets the optional ignore_lookup_error attribute to value. +// DecodeAndCropJpegChannels sets the optional channels attribute to value. // -// value: whether to ignore the error when the resource -// doesn't exist. -// If not specified, defaults to true -func DestroyResourceOpIgnoreLookupError(value bool) DestroyResourceOpAttr { +// value: Number of color channels for the decoded image. +// If not specified, defaults to 0 +func DecodeAndCropJpegChannels(value int64) DecodeAndCropJpegAttr { return func(m optionalAttr) { - m["ignore_lookup_error"] = value + m["channels"] = value } } -// Deletes the resource specified by the handle. -// -// All subsequent operations using the resource will result in a NotFound -// error status. -// -// Arguments: -// resource: handle to the resource to delete. +// DecodeAndCropJpegRatio sets the optional ratio attribute to value. // -// Returns the created operation. -func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyResourceOpAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DestroyResourceOp", - Input: []tf.Input{ - resource, - }, - Attrs: attrs, +// value: Downscaling ratio. +// If not specified, defaults to 1 +func DecodeAndCropJpegRatio(value int64) DecodeAndCropJpegAttr { + return func(m optionalAttr) { + m["ratio"] = value } - return scope.AddOperation(opspec) } -// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp. -type ResourceSparseApplyRMSPropAttr func(optionalAttr) - -// ResourceSparseApplyRMSPropUseLocking sets the optional use_locking attribute to value. +// DecodeAndCropJpegFancyUpscaling sets the optional fancy_upscaling attribute to value. // -// value: If `True`, updating of the var, ms, and mom tensors is protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSPropAttr { +// value: If true use a slower but nicer upscaling of the +// chroma planes (yuv420/422 only). +// If not specified, defaults to true +func DecodeAndCropJpegFancyUpscaling(value bool) DecodeAndCropJpegAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["fancy_upscaling"] = value } } -// Update '*var' according to the RMSProp algorithm. +// DecodeAndCropJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value. // -// Note that in dense implementation of this algorithm, ms and mom will -// update even if the grad is zero, but in this sparse implementation, ms -// and mom will not update in iterations during which the grad is zero. +// value: If true try to recover an image from truncated input. +// If not specified, defaults to false +func DecodeAndCropJpegTryRecoverTruncated(value bool) DecodeAndCropJpegAttr { + return func(m optionalAttr) { + m["try_recover_truncated"] = value + } +} + +// DecodeAndCropJpegAcceptableFraction sets the optional acceptable_fraction attribute to value. // -// mean_square = decay * mean_square + (1-decay) * gradient ** 2 -// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) +// value: The minimum required fraction of lines before a truncated +// input is accepted. +// If not specified, defaults to 1 +func DecodeAndCropJpegAcceptableFraction(value float32) DecodeAndCropJpegAttr { + return func(m optionalAttr) { + m["acceptable_fraction"] = value + } +} + +// DecodeAndCropJpegDctMethod sets the optional dct_method attribute to value. // -// ms <- rho * ms_{t-1} + (1-rho) * grad * grad -// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) -// var <- var - mom +// value: string specifying a hint about the algorithm used for +// decompression. Defaults to "" which maps to a system-specific +// default. Currently valid values are ["INTEGER_FAST", +// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal +// jpeg library changes to a version that does not have that specific +// option.) +// If not specified, defaults to "" +func DecodeAndCropJpegDctMethod(value string) DecodeAndCropJpegAttr { + return func(m optionalAttr) { + m["dct_method"] = value + } +} + +// Decode and Crop a JPEG-encoded image to a uint8 tensor. // -// Arguments: -// var_: Should be from a Variable(). -// ms: Should be from a Variable(). -// mom: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// rho: Decay rate. Must be a scalar. +// The attr `channels` indicates the desired number of color channels for the +// decoded image. // -// epsilon: Ridge term. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var, ms and mom. +// Accepted values are: // -// Returns the created operation. -func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyRMSPropAttr) (o *tf.Operation) { +// * 0: Use the number of channels in the JPEG-encoded image. +// * 1: output a grayscale image. +// * 3: output an RGB image. +// +// If needed, the JPEG-encoded image is transformed to match the requested number +// of color channels. +// +// The attr `ratio` allows downscaling the image by an integer factor during +// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than +// downscaling the image later. +// +// +// It is equivalent to a combination of decode and crop, but much faster by only +// decoding partial jpeg image. +// +// Arguments: +// contents: 0-D. The JPEG-encoded image. +// crop_window: 1-D. The crop window: [crop_y, crop_x, crop_height, crop_width]. +// +// Returns 3-D with shape `[height, width, channels]`.. +func DecodeAndCropJpeg(scope *Scope, contents tf.Output, crop_window tf.Output, optional ...DecodeAndCropJpegAttr) (image tf.Output) { if scope.Err() != nil { return } @@ -10876,179 +10861,81 @@ func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyRMSProp", + Type: "DecodeAndCropJpeg", Input: []tf.Input{ - var_, ms, mom, lr, rho, momentum, epsilon, grad, indices, + contents, crop_window, }, Attrs: attrs, } - return scope.AddOperation(opspec) -} - -// Returns the truth value of (x > y) element-wise. -// -// *NOTE*: `Greater` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Greater(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Greater", - Input: []tf.Input{ - x, y, - }, - } op := scope.AddOperation(opspec) return op.Output(0) } -// SampleDistortedBoundingBoxAttr is an optional argument to SampleDistortedBoundingBox. -type SampleDistortedBoundingBoxAttr func(optionalAttr) +// AllCandidateSamplerAttr is an optional argument to AllCandidateSampler. +type AllCandidateSamplerAttr func(optionalAttr) -// SampleDistortedBoundingBoxSeed sets the optional seed attribute to value. +// AllCandidateSamplerSeed sets the optional seed attribute to value. // -// value: If either `seed` or `seed2` are set to non-zero, the random number -// generator is seeded by the given `seed`. Otherwise, it is seeded by a random -// seed. +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. // If not specified, defaults to 0 -func SampleDistortedBoundingBoxSeed(value int64) SampleDistortedBoundingBoxAttr { +func AllCandidateSamplerSeed(value int64) AllCandidateSamplerAttr { return func(m optionalAttr) { m["seed"] = value } } -// SampleDistortedBoundingBoxSeed2 sets the optional seed2 attribute to value. +// AllCandidateSamplerSeed2 sets the optional seed2 attribute to value. // -// value: A second seed to avoid seed collision. +// value: An second seed to avoid seed collision. // If not specified, defaults to 0 -func SampleDistortedBoundingBoxSeed2(value int64) SampleDistortedBoundingBoxAttr { +func AllCandidateSamplerSeed2(value int64) AllCandidateSamplerAttr { return func(m optionalAttr) { m["seed2"] = value } } -// SampleDistortedBoundingBoxMinObjectCovered sets the optional min_object_covered attribute to value. -// -// value: The cropped area of the image must contain at least this -// fraction of any bounding box supplied. The value of this parameter should be -// non-negative. In the case of 0, the cropped area does not need to overlap -// any of the bounding boxes supplied. -// If not specified, defaults to 0.1 -func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBoundingBoxAttr { - return func(m optionalAttr) { - m["min_object_covered"] = value - } -} - -// SampleDistortedBoundingBoxAspectRatioRange sets the optional aspect_ratio_range attribute to value. -// -// value: The cropped area of the image must have an aspect ratio = -// width / height within this range. -// If not specified, defaults to -func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { - return func(m optionalAttr) { - m["aspect_ratio_range"] = value - } -} - -// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value. -// -// value: The cropped area of the image must contain a fraction of the -// supplied image within this range. -// If not specified, defaults to -func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { - return func(m optionalAttr) { - m["area_range"] = value - } -} - -// SampleDistortedBoundingBoxMaxAttempts sets the optional max_attempts attribute to value. -// -// value: Number of attempts at generating a cropped region of the image -// of the specified constraints. After `max_attempts` failures, return the entire -// image. -// If not specified, defaults to 100 -func SampleDistortedBoundingBoxMaxAttempts(value int64) SampleDistortedBoundingBoxAttr { - return func(m optionalAttr) { - m["max_attempts"] = value - } -} - -// SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value. -// -// value: Controls behavior if no bounding boxes supplied. -// If true, assume an implicit bounding box covering the whole input. If false, -// raise an error. -// If not specified, defaults to false -func SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxAttr { - return func(m optionalAttr) { - m["use_image_if_no_bounding_boxes"] = value - } -} - -// Generate a single randomly distorted bounding box for an image. -// -// Bounding box annotations are often supplied in addition to ground-truth labels -// in image recognition or object localization tasks. A common technique for -// training such a system is to randomly distort an image while preserving -// its content, i.e. *data augmentation*. This Op outputs a randomly distorted -// localization of an object, i.e. bounding box, given an `image_size`, -// `bounding_boxes` and a series of constraints. -// -// The output of this Op is a single bounding box that may be used to crop the -// original image. The output is returned as 3 tensors: `begin`, `size` and -// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the -// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize -// what the bounding box looks like. -// -// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The -// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and -// height of the underlying image. -// -// For example, -// -// ```python -// # Generate a single distorted bounding box. -// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box( -// tf.shape(image), -// bounding_boxes=bounding_boxes) +// Generates labels for candidate sampling with a learned unigram distribution. // -// # Draw the bounding box in an image summary. -// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), -// bbox_for_draw) -// tf.summary.image('images_with_box', image_with_box) +// See explanations of candidate sampling and the data formats at +// go/candidate-sampling. // -// # Employ the bounding box to distort the image. -// distorted_image = tf.slice(image, begin, size) -// ``` +// For each batch, this op picks a single set of sampled candidate labels. // -// Note that if no bounding box information is available, setting -// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit -// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is -// false and no bounding boxes are supplied, an error is raised. +// The advantages of sampling candidates per-batch are simplicity and the +// possibility of efficient dense matrix multiplication. The disadvantage is that +// the sampled candidates must be chosen independently of the context and of the +// true labels. // // Arguments: -// image_size: 1-D, containing `[height, width, channels]`. -// bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes -// associated with the image. +// true_classes: A batch_size * num_true matrix, in which each row contains the +// IDs of the num_true target_classes in the corresponding original label. +// num_true: Number of true labels per context. +// num_sampled: Number of candidates to produce. +// unique: If unique is true, we sample with rejection, so that all sampled +// candidates in a batch are unique. This requires some approximation to +// estimate the post-rejection sampling probabilities. // -// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to -// `tf.slice`.1-D, containing `[target_height, target_width, -1]`. Provide as input to -// `tf.slice`.3-D with shape `[1, 1, 4]` containing the distorted bounding box. -// Provide as input to `tf.image.draw_bounding_boxes`. -func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, optional ...SampleDistortedBoundingBoxAttr) (begin tf.Output, size tf.Output, bboxes tf.Output) { +// Returns A vector of length num_sampled, in which each element is +// the ID of a sampled candidate.A batch_size * num_true matrix, representing +// the number of times each candidate is expected to occur in a batch +// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled +// candidate representing the number of times the candidate is expected +// to occur in a batch of sampled candidates. If unique=true, then this is a +// probability. +func AllCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, optional ...AllCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "SampleDistortedBoundingBox", + Type: "AllCandidateSampler", Input: []tf.Input{ - image_size, bounding_boxes, + true_classes, }, Attrs: attrs, } @@ -11056,487 +10943,186 @@ func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_box return op.Output(0), op.Output(1), op.Output(2) } -// LRNAttr is an optional argument to LRN. -type LRNAttr func(optionalAttr) - -// LRNDepthRadius sets the optional depth_radius attribute to value. +// Adds two `SparseTensor` objects to produce another `SparseTensor`. // -// value: 0-D. Half-width of the 1-D normalization window. -// If not specified, defaults to 5 -func LRNDepthRadius(value int64) LRNAttr { - return func(m optionalAttr) { - m["depth_radius"] = value - } -} - -// LRNBias sets the optional bias attribute to value. -// -// value: An offset (usually positive to avoid dividing by 0). -// If not specified, defaults to 1 -func LRNBias(value float32) LRNAttr { - return func(m optionalAttr) { - m["bias"] = value - } -} - -// LRNAlpha sets the optional alpha attribute to value. -// -// value: A scale factor, usually positive. -// If not specified, defaults to 1 -func LRNAlpha(value float32) LRNAttr { - return func(m optionalAttr) { - m["alpha"] = value - } -} - -// LRNBeta sets the optional beta attribute to value. -// -// value: An exponent. -// If not specified, defaults to 0.5 -func LRNBeta(value float32) LRNAttr { - return func(m optionalAttr) { - m["beta"] = value - } -} - -// Local Response Normalization. -// -// The 4-D `input` tensor is treated as a 3-D array of 1-D vectors (along the last -// dimension), and each vector is normalized independently. Within a given vector, -// each component is divided by the weighted, squared sum of inputs within -// `depth_radius`. In detail, +// The input `SparseTensor` objects' indices are assumed ordered in standard +// lexicographic order. If this is not the case, before this step run +// `SparseReorder` to restore index ordering. // -// sqr_sum[a, b, c, d] = -// sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2) -// output = input / (bias + alpha * sqr_sum) ** beta +// By default, if two values sum to zero at some index, the output `SparseTensor` +// would still include that particular location in its index, storing a zero in the +// corresponding value slot. To override this, callers can specify `thresh`, +// indicating that if the sum has a magnitude strictly smaller than `thresh`, its +// corresponding value and index would then not be included. In particular, +// `thresh == 0` (default) means everything is kept and actual thresholding happens +// only for a positive value. // -// For details, see [Krizhevsky et al., ImageNet classification with deep -// convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks). +// In the following shapes, `nnz` is the count after taking `thresh` into account. // // Arguments: -// input: 4-D. -func LRN(scope *Scope, input tf.Output, optional ...LRNAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LRN", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that zips together `input_datasets`. -func ZipDataset(scope *Scope, input_datasets []tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// a_indices: 2-D. The `indices` of the first `SparseTensor`, size `[nnz, ndims]` Matrix. +// a_values: 1-D. The `values` of the first `SparseTensor`, size `[nnz]` Vector. +// a_shape: 1-D. The `shape` of the first `SparseTensor`, size `[ndims]` Vector. +// b_indices: 2-D. The `indices` of the second `SparseTensor`, size `[nnz, ndims]` Matrix. +// b_values: 1-D. The `values` of the second `SparseTensor`, size `[nnz]` Vector. +// b_shape: 1-D. The `shape` of the second `SparseTensor`, size `[ndims]` Vector. +// thresh: 0-D. The magnitude threshold that determines if an output value/index +// pair takes space. +func SparseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output, thresh tf.Output) (sum_indices tf.Output, sum_values tf.Output, sum_shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "ZipDataset", + Type: "SparseAdd", Input: []tf.Input{ - tf.OutputList(input_datasets), + a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// ResourceSparseApplyAdagradAttr is an optional argument to ResourceSparseApplyAdagrad. -type ResourceSparseApplyAdagradAttr func(optionalAttr) +// OrderedMapPeekAttr is an optional argument to OrderedMapPeek. +type OrderedMapPeekAttr func(optionalAttr) -// ResourceSparseApplyAdagradUseLocking sets the optional use_locking attribute to value. +// OrderedMapPeekCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyAdagradUseLocking(value bool) ResourceSparseApplyAdagradAttr { +// REQUIRES: value >= 0 +func OrderedMapPeekCapacity(value int64) OrderedMapPeekAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["capacity"] = value } } -// ResourceSparseApplyAdagradUpdateSlots sets the optional update_slots attribute to value. -// If not specified, defaults to true -func ResourceSparseApplyAdagradUpdateSlots(value bool) ResourceSparseApplyAdagradAttr { +// OrderedMapPeekMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func OrderedMapPeekMemoryLimit(value int64) OrderedMapPeekAttr { return func(m optionalAttr) { - m["update_slots"] = value + m["memory_limit"] = value } } -// Update relevant entries in '*var' and '*accum' according to the adagrad scheme. -// -// That is for rows we have grad for, we update var and accum as follows: -// accum += grad * grad -// var -= lr * grad * (1 / sqrt(accum)) -// -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Learning rate. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// -// Returns the created operation. -func ResourceSparseApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyAdagradAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceSparseApplyAdagrad", - Input: []tf.Input{ - var_, accum, lr, grad, indices, - }, - Attrs: attrs, +// OrderedMapPeekContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func OrderedMapPeekContainer(value string) OrderedMapPeekAttr { + return func(m optionalAttr) { + m["container"] = value } - return scope.AddOperation(opspec) } -// StatelessRandomUniformAttr is an optional argument to StatelessRandomUniform. -type StatelessRandomUniformAttr func(optionalAttr) - -// StatelessRandomUniformDtype sets the optional dtype attribute to value. -// -// value: The type of the output. -// If not specified, defaults to DT_FLOAT -func StatelessRandomUniformDtype(value tf.DataType) StatelessRandomUniformAttr { +// OrderedMapPeekSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func OrderedMapPeekSharedName(value string) OrderedMapPeekAttr { return func(m optionalAttr) { - m["dtype"] = value + m["shared_name"] = value } } -// Outputs deterministic pseudorandom random values from a uniform distribution. -// -// The generated values follow a uniform distribution in the range `[0, 1)`. The -// lower bound 0 is included in the range, while the upper bound 1 is excluded. -// -// The outputs are a deterministic function of `shape` and `seed`. -// -// Arguments: -// shape: The shape of the output tensor. -// seed: 2 seeds (shape [2]). +// Op peeks at the values at the specified key. If the // -// Returns Random values with specified shape. -func StatelessRandomUniform(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomUniformAttr) (output tf.Output) { +// underlying container does not contain this key +// this op will block until it does. This Op is optimized for +// performance. +func OrderedMapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...OrderedMapPeekAttr) (values []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "StatelessRandomUniform", + Type: "OrderedMapPeek", Input: []tf.Input{ - shape, seed, + key, indices, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Makes its input available to the next iteration. -// -// Arguments: -// data: The tensor to be made available to the next iteration. -// -// Returns The same tensor as `data`. -func NextIteration(scope *Scope, data tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "NextIteration", - Input: []tf.Input{ - data, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Output a fact about factorials. -func Fact(scope *Scope) (fact tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Fact", - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Elementwise computes the bitwise XOR of `x` and `y`. -// -// The result will have those bits set, that are different in `x` and `y`. The -// computation is performed on the underlying representations of `x` and `y`. -func BitwiseXor(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "BitwiseXor", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Deserialize `SparseTensor` objects. -// -// The input `serialized_sparse` must have the shape `[?, ?, ..., ?, 3]` where -// the last dimension stores serialized `SparseTensor` objects and the other N -// dimensions (N >= 0) correspond to a batch. The ranks of the original -// `SparseTensor` objects must all match. When the final `SparseTensor` is -// created, its rank is the rank of the incoming `SparseTensor` objects plus N; -// the sparse tensors have been concatenated along new dimensions, one for each -// batch. -// -// The output `SparseTensor` object's shape values for the original dimensions -// are the max across the input `SparseTensor` objects' shape values for the -// corresponding dimensions. The new dimensions match the size of the batch. -// -// The input `SparseTensor` objects' indices are assumed ordered in -// standard lexicographic order. If this is not the case, after this -// step run `SparseReorder` to restore index ordering. -// -// For example, if the serialized input is a `[2 x 3]` matrix representing two -// original `SparseTensor` objects: -// -// index = [ 0] -// [10] -// [20] -// values = [1, 2, 3] -// shape = [50] -// -// and -// -// index = [ 2] -// [10] -// values = [4, 5] -// shape = [30] -// -// then the final deserialized `SparseTensor` will be: -// -// index = [0 0] -// [0 10] -// [0 20] -// [1 2] -// [1 10] -// values = [1, 2, 3, 4, 5] -// shape = [2 50] -// -// Arguments: -// serialized_sparse: The serialized `SparseTensor` objects. The last dimension -// must have 3 columns. -// dtype: The `dtype` of the serialized `SparseTensor` objects. -func DeserializeSparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataType) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { - if scope.Err() != nil { + var idx int + var err error + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("OrderedMapPeek", err) return } - attrs := map[string]interface{}{"dtype": dtype} - opspec := tf.OpSpec{ - Type: "DeserializeSparse", - Input: []tf.Input{ - serialized_sparse, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// ResourceScatterNdUpdateAttr is an optional argument to ResourceScatterNdUpdate. -type ResourceScatterNdUpdateAttr func(optionalAttr) - -// ResourceScatterNdUpdateUseLocking sets the optional use_locking attribute to value. -// -// value: An optional bool. Defaults to True. If True, the assignment will -// be protected by a lock; otherwise the behavior is undefined, -// but may exhibit less contention. -// If not specified, defaults to true -func ResourceScatterNdUpdateUseLocking(value bool) ResourceScatterNdUpdateAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } + return values } -// Applies sparse `updates` to individual values or slices within a given -// -// variable according to `indices`. -// -// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. -// -// `indices` must be integer tensor, containing indices into `ref`. -// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. -// -// The innermost dimension of `indices` (with length `K`) corresponds to -// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th -// dimension of `ref`. -// -// `updates` is `Tensor` of rank `Q-1+P-K` with shape: -// -// ``` -// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. -// ``` -// -// For example, say we want to update 4 scattered elements to a rank-1 tensor to -// 8 elements. In Python, that update would look like this: -// -// ```python -// ref = tfe.Variable([1, 2, 3, 4, 5, 6, 7, 8]) -// indices = tf.constant([[4], [3], [1] ,[7]]) -// updates = tf.constant([9, 10, 11, 12]) -// update = tf.scatter_nd_update(ref, indices, updates) -// with tf.Session() as sess: -// print sess.run(update) -// ``` -// -// The resulting update to ref would look like this: -// -// [1, 11, 3, 10, 9, 6, 7, 12] +// Inverse fast Fourier transform. // -// See @{tf.scatter_nd} for more details about how to make updates to -// slices. +// Computes the inverse 1-dimensional discrete Fourier transform over the +// inner-most dimension of `input`. // // Arguments: -// ref: A resource handle. Must be from a VarHandleOp. -// indices: A Tensor. Must be one of the following types: int32, int64. -// A tensor of indices into ref. -// updates: A Tensor. Must have the same type as ref. A tensor of updated -// values to add to ref. -// -// Returns the created operation. -func ResourceScatterNdUpdate(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdUpdateAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceScatterNdUpdate", - Input: []tf.Input{ - ref, indices, updates, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// SqueezeAttr is an optional argument to Squeeze. -type SqueezeAttr func(optionalAttr) - -// SqueezeAxis sets the optional axis attribute to value. -// -// value: If specified, only squeezes the dimensions listed. The dimension -// index starts at 0. It is an error to squeeze a dimension that is not 1. Must -// be in the range `[-rank(input), rank(input))`. -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func SqueezeAxis(value []int64) SqueezeAttr { - return func(m optionalAttr) { - m["squeeze_dims"] = value - } -} - -// Removes dimensions of size 1 from the shape of a tensor. -// -// Given a tensor `input`, this operation returns a tensor of the same type with -// all dimensions of size 1 removed. If you don't want to remove all size 1 -// dimensions, you can remove specific size 1 dimensions by specifying -// `axis`. -// -// For example: -// -// ``` -// # 't' is a tensor of shape [1, 2, 1, 3, 1, 1] -// shape(squeeze(t)) ==> [2, 3] -// ``` -// -// Or, to remove specific size 1 dimensions: -// -// ``` -// # 't' is a tensor of shape [1, 2, 1, 3, 1, 1] -// shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1] -// ``` +// input: A complex64 tensor. // -// Arguments: -// input: The `input` to squeeze. +// Returns A complex64 tensor of the same shape as `input`. The inner-most +// dimension of `input` is replaced with its inverse 1D Fourier transform. // -// Returns Contains the same data as `input`, but has one or more dimensions of -// size 1 removed. -func Squeeze(scope *Scope, input tf.Output, optional ...SqueezeAttr) (output tf.Output) { +// @compatibility(numpy) +// Equivalent to np.fft.ifft +// @end_compatibility +func IFFT(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Squeeze", + Type: "IFFT", Input: []tf.Input{ input, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResourceApplyAdadeltaAttr is an optional argument to ResourceApplyAdadelta. -type ResourceApplyAdadeltaAttr func(optionalAttr) +// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp. +type ResourceSparseApplyRMSPropAttr func(optionalAttr) -// ResourceApplyAdadeltaUseLocking sets the optional use_locking attribute to value. +// ResourceSparseApplyRMSPropUseLocking sets the optional use_locking attribute to value. // -// value: If True, updating of the var, accum and update_accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// value: If `True`, updating of the var, ms, and mom tensors is protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. // If not specified, defaults to false -func ResourceApplyAdadeltaUseLocking(value bool) ResourceApplyAdadeltaAttr { +func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSPropAttr { return func(m optionalAttr) { m["use_locking"] = value } } -// Update '*var' according to the adadelta scheme. +// Update '*var' according to the RMSProp algorithm. // -// accum = rho() * accum + (1 - rho()) * grad.square(); -// update = (update_accum + epsilon).sqrt() * (accum + epsilon()).rsqrt() * grad; -// update_accum = rho() * update_accum + (1 - rho()) * update.square(); -// var -= update; +// Note that in dense implementation of this algorithm, ms and mom will +// update even if the grad is zero, but in this sparse implementation, ms +// and mom will not update in iterations during which the grad is zero. +// +// mean_square = decay * mean_square + (1-decay) * gradient ** 2 +// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) +// +// ms <- rho * ms_{t-1} + (1-rho) * grad * grad +// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) +// var <- var - mom // // Arguments: // var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// accum_update: Should be from a Variable(). +// ms: Should be from a Variable(). +// mom: Should be from a Variable(). // lr: Scaling factor. Must be a scalar. -// rho: Decay factor. Must be a scalar. -// epsilon: Constant factor. Must be a scalar. +// rho: Decay rate. Must be a scalar. +// +// epsilon: Ridge term. Must be a scalar. // grad: The gradient. +// indices: A vector of indices into the first dimension of var, ms and mom. // // Returns the created operation. -func ResourceApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_update tf.Output, lr tf.Output, rho tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdadeltaAttr) (o *tf.Operation) { +func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyRMSPropAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -11545,292 +11131,373 @@ func ResourceApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_ a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyAdadelta", + Type: "ResourceSparseApplyRMSProp", Input: []tf.Input{ - var_, accum, accum_update, lr, rho, epsilon, grad, + var_, ms, mom, lr, rho, momentum, epsilon, grad, indices, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// NonMaxSuppressionAttr is an optional argument to NonMaxSuppression. -type NonMaxSuppressionAttr func(optionalAttr) - -// NonMaxSuppressionIouThreshold sets the optional iou_threshold attribute to value. -// -// value: A float representing the threshold for deciding whether boxes -// overlap too much with respect to IOU. -// If not specified, defaults to 0.5 -func NonMaxSuppressionIouThreshold(value float32) NonMaxSuppressionAttr { - return func(m optionalAttr) { - m["iou_threshold"] = value - } -} - -// Greedily selects a subset of bounding boxes in descending order of score, -// -// pruning away boxes that have high intersection-over-union (IOU) overlap -// with previously selected boxes. Bounding boxes are supplied as -// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any -// diagonal pair of box corners and the coordinates can be provided as normalized -// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm -// is agnostic to where the origin is in the coordinate system. Note that this -// algorithm is invariant to orthogonal transformations and translations -// of the coordinate system; thus translating or reflections of the coordinate -// system result in the same boxes being selected by the algorithm. -// The output of this operation is a set of integers indexing into the input -// collection of bounding boxes representing the selected boxes. The bounding -// box coordinates corresponding to the selected indices can then be obtained -// using the `tf.gather operation`. For example: -// selected_indices = tf.image.non_max_suppression( -// boxes, scores, max_output_size, iou_threshold) -// selected_boxes = tf.gather(boxes, selected_indices) -// -// Arguments: -// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. -// scores: A 1-D float tensor of shape `[num_boxes]` representing a single -// score corresponding to each box (each row of boxes). -// max_output_size: A scalar integer tensor representing the maximum number of -// boxes to be selected by non max suppression. +// Returns the truth value of (x > y) element-wise. // -// Returns A 1-D integer tensor of shape `[M]` representing the selected -// indices from the boxes tensor, where `M <= max_output_size`. -func NonMaxSuppression(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, optional ...NonMaxSuppressionAttr) (selected_indices tf.Output) { +// *NOTE*: `Greater` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Greater(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "NonMaxSuppression", + Type: "Greater", Input: []tf.Input{ - boxes, scores, max_output_size, + x, y, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Creates a dataset that emits `components` as a tuple of tensors once. -func TensorDataset(scope *Scope, components []tf.Output, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return +// SampleDistortedBoundingBoxAttr is an optional argument to SampleDistortedBoundingBox. +type SampleDistortedBoundingBoxAttr func(optionalAttr) + +// SampleDistortedBoundingBoxSeed sets the optional seed attribute to value. +// +// value: If either `seed` or `seed2` are set to non-zero, the random number +// generator is seeded by the given `seed`. Otherwise, it is seeded by a random +// seed. +// If not specified, defaults to 0 +func SampleDistortedBoundingBoxSeed(value int64) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["seed"] = value } - attrs := map[string]interface{}{"output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "TensorDataset", - Input: []tf.Input{ - tf.OutputList(components), - }, - Attrs: attrs, +} + +// SampleDistortedBoundingBoxSeed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func SampleDistortedBoundingBoxSeed2(value int64) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["seed2"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Component-wise multiplies a SparseTensor by a dense Tensor. +// SampleDistortedBoundingBoxMinObjectCovered sets the optional min_object_covered attribute to value. // -// The output locations corresponding to the implicitly zero elements in the sparse -// tensor will be zero (i.e., will not take up storage space), regardless of the -// contents of the dense tensor (even if it's +/-INF and that INF*0 == NaN). +// value: The cropped area of the image must contain at least this +// fraction of any bounding box supplied. The value of this parameter should be +// non-negative. In the case of 0, the cropped area does not need to overlap +// any of the bounding boxes supplied. +// If not specified, defaults to 0.1 +func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["min_object_covered"] = value + } +} + +// SampleDistortedBoundingBoxAspectRatioRange sets the optional aspect_ratio_range attribute to value. // -// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not -// the other direction. +// value: The cropped area of the image must have an aspect ratio = +// width / height within this range. +// If not specified, defaults to +func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["aspect_ratio_range"] = value + } +} + +// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value. // -// Arguments: -// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`. -// sp_shape: 1-D. Shape of the input SparseTensor. -// dense: `R`-D. The dense Tensor operand. +// value: The cropped area of the image must contain a fraction of the +// supplied image within this range. +// If not specified, defaults to +func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["area_range"] = value + } +} + +// SampleDistortedBoundingBoxMaxAttempts sets the optional max_attempts attribute to value. // -// Returns 1-D. The `N` values that are operated on. -func SparseDenseCwiseMul(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) { - if scope.Err() != nil { - return +// value: Number of attempts at generating a cropped region of the image +// of the specified constraints. After `max_attempts` failures, return the entire +// image. +// If not specified, defaults to 100 +func SampleDistortedBoundingBoxMaxAttempts(value int64) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["max_attempts"] = value } - opspec := tf.OpSpec{ - Type: "SparseDenseCwiseMul", - Input: []tf.Input{ - sp_indices, sp_values, sp_shape, dense, - }, +} + +// SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value. +// +// value: Controls behavior if no bounding boxes supplied. +// If true, assume an implicit bounding box covering the whole input. If false, +// raise an error. +// If not specified, defaults to false +func SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["use_image_if_no_bounding_boxes"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// 2D real-valued fast Fourier transform. +// Generate a single randomly distorted bounding box for an image. // -// Computes the 2-dimensional discrete Fourier transform of a real-valued signal -// over the inner-most 2 dimensions of `input`. +// Bounding box annotations are often supplied in addition to ground-truth labels +// in image recognition or object localization tasks. A common technique for +// training such a system is to randomly distort an image while preserving +// its content, i.e. *data augmentation*. This Op outputs a randomly distorted +// localization of an object, i.e. bounding box, given an `image_size`, +// `bounding_boxes` and a series of constraints. // -// Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the -// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension -// of `output`: the zero-frequency term, followed by the `fft_length / 2` -// positive-frequency terms. +// The output of this Op is a single bounding box that may be used to crop the +// original image. The output is returned as 3 tensors: `begin`, `size` and +// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the +// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize +// what the bounding box looks like. // -// Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the -// corresponding dimension of `input`, the dimension is cropped. If it is larger, -// the dimension is padded with zeros. +// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The +// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and +// height of the underlying image. +// +// For example, +// +// ```python +// # Generate a single distorted bounding box. +// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box( +// tf.shape(image), +// bounding_boxes=bounding_boxes) +// +// # Draw the bounding box in an image summary. +// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), +// bbox_for_draw) +// tf.summary.image('images_with_box', image_with_box) +// +// # Employ the bounding box to distort the image. +// distorted_image = tf.slice(image, begin, size) +// ``` +// +// Note that if no bounding box information is available, setting +// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit +// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is +// false and no bounding boxes are supplied, an error is raised. // // Arguments: -// input: A float32 tensor. -// fft_length: An int32 tensor of shape [2]. The FFT length for each dimension. -// -// Returns A complex64 tensor of the same rank as `input`. The inner-most 2 -// dimensions of `input` are replaced with their 2D Fourier transform. The -// inner-most dimension contains `fft_length / 2 + 1` unique frequency -// components. +// image_size: 1-D, containing `[height, width, channels]`. +// bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes +// associated with the image. // -// @compatibility(numpy) -// Equivalent to np.fft.rfft2 -// @end_compatibility -func RFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { +// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to +// `tf.slice`.1-D, containing `[target_height, target_width, -1]`. Provide as input to +// `tf.slice`.3-D with shape `[1, 1, 4]` containing the distorted bounding box. +// Provide as input to `tf.image.draw_bounding_boxes`. +func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, optional ...SampleDistortedBoundingBoxAttr) (begin tf.Output, size tf.Output, bboxes tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "RFFT2D", + Type: "SampleDistortedBoundingBox", Input: []tf.Input{ - input, fft_length, + image_size, bounding_boxes, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Pads a tensor with zeros. +// LRNAttr is an optional argument to LRN. +type LRNAttr func(optionalAttr) + +// LRNDepthRadius sets the optional depth_radius attribute to value. // -// This operation pads a `input` with zeros according to the `paddings` you -// specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the -// rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates -// how many zeros to add before the contents of `input` in that dimension, and -// `paddings[D, 1]` indicates how many zeros to add after the contents of `input` -// in that dimension. +// value: 0-D. Half-width of the 1-D normalization window. +// If not specified, defaults to 5 +func LRNDepthRadius(value int64) LRNAttr { + return func(m optionalAttr) { + m["depth_radius"] = value + } +} + +// LRNBias sets the optional bias attribute to value. // -// The padded size of each dimension D of the output is: +// value: An offset (usually positive to avoid dividing by 0). +// If not specified, defaults to 1 +func LRNBias(value float32) LRNAttr { + return func(m optionalAttr) { + m["bias"] = value + } +} + +// LRNAlpha sets the optional alpha attribute to value. // -// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` +// value: A scale factor, usually positive. +// If not specified, defaults to 1 +func LRNAlpha(value float32) LRNAttr { + return func(m optionalAttr) { + m["alpha"] = value + } +} + +// LRNBeta sets the optional beta attribute to value. // -// For example: +// value: An exponent. +// If not specified, defaults to 0.5 +func LRNBeta(value float32) LRNAttr { + return func(m optionalAttr) { + m["beta"] = value + } +} + +// Local Response Normalization. // -// ``` -// # 't' is [[1, 1], [2, 2]] -// # 'paddings' is [[1, 1], [2, 2]] -// # rank of 't' is 2 -// pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] -// [0, 0, 1, 1, 0, 0] -// [0, 0, 2, 2, 0, 0] -// [0, 0, 0, 0, 0, 0]] -// ``` -func Pad(scope *Scope, input tf.Output, paddings tf.Output) (output tf.Output) { +// The 4-D `input` tensor is treated as a 3-D array of 1-D vectors (along the last +// dimension), and each vector is normalized independently. Within a given vector, +// each component is divided by the weighted, squared sum of inputs within +// `depth_radius`. In detail, +// +// sqr_sum[a, b, c, d] = +// sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2) +// output = input / (bias + alpha * sqr_sum) ** beta +// +// For details, see [Krizhevsky et al., ImageNet classification with deep +// convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks). +// +// Arguments: +// input: 4-D. +func LRN(scope *Scope, input tf.Output, optional ...LRNAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Pad", + Type: "LRN", Input: []tf.Input{ - input, paddings, + input, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Checks whether a resource handle-based variable has been initialized. -// -// Arguments: -// resource: the input resource handle. -// -// Returns a scalar boolean which is true if the variable has been -// initialized. -func VarIsInitializedOp(scope *Scope, resource tf.Output) (is_initialized tf.Output) { +// Creates a dataset that zips together `input_datasets`. +func ZipDataset(scope *Scope, input_datasets []tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "VarIsInitializedOp", + Type: "ZipDataset", Input: []tf.Input{ - resource, + tf.OutputList(input_datasets), }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Converts each string in the input Tensor to its hash mod by a number of buckets. +// ResourceSparseApplyAdagradAttr is an optional argument to ResourceSparseApplyAdagrad. +type ResourceSparseApplyAdagradAttr func(optionalAttr) + +// ResourceSparseApplyAdagradUseLocking sets the optional use_locking attribute to value. // -// The hash function is deterministic on the content of the string within the -// process and will never change. However, it is not suitable for cryptography. -// This function may be used when CPU time is scarce and inputs are trusted or -// unimportant. There is a risk of adversaries constructing inputs that all hash -// to the same bucket. To prevent this problem, use a strong hash function with -// `tf.string_to_hash_bucket_strong`. +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyAdagradUseLocking(value bool) ResourceSparseApplyAdagradAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// ResourceSparseApplyAdagradUpdateSlots sets the optional update_slots attribute to value. +// If not specified, defaults to true +func ResourceSparseApplyAdagradUpdateSlots(value bool) ResourceSparseApplyAdagradAttr { + return func(m optionalAttr) { + m["update_slots"] = value + } +} + +// Update relevant entries in '*var' and '*accum' according to the adagrad scheme. +// +// That is for rows we have grad for, we update var and accum as follows: +// accum += grad * grad +// var -= lr * grad * (1 / sqrt(accum)) // // Arguments: -// input: The strings to assign a hash bucket. -// num_buckets: The number of buckets. +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Learning rate. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. // -// Returns A Tensor of the same shape as the input `string_tensor`. -func StringToHashBucketFast(scope *Scope, input tf.Output, num_buckets int64) (output tf.Output) { +// Returns the created operation. +func ResourceSparseApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyAdagradAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_buckets": num_buckets} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "StringToHashBucketFast", + Type: "ResourceSparseApplyAdagrad", Input: []tf.Input{ - input, + var_, accum, lr, grad, indices, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// TensorArrayGatherV3Attr is an optional argument to TensorArrayGatherV3. -type TensorArrayGatherV3Attr func(optionalAttr) +// StatelessRandomUniformAttr is an optional argument to StatelessRandomUniform. +type StatelessRandomUniformAttr func(optionalAttr) -// TensorArrayGatherV3ElementShape sets the optional element_shape attribute to value. +// StatelessRandomUniformDtype sets the optional dtype attribute to value. // -// value: The expected shape of an element, if known. Used to -// validate the shapes of TensorArray elements. If this shape is not -// fully specified, gathering zero-size TensorArrays is an error. -// If not specified, defaults to -func TensorArrayGatherV3ElementShape(value tf.Shape) TensorArrayGatherV3Attr { +// value: The type of the output. +// If not specified, defaults to DT_FLOAT +func StatelessRandomUniformDtype(value tf.DataType) StatelessRandomUniformAttr { return func(m optionalAttr) { - m["element_shape"] = value + m["dtype"] = value } } -// Gather specific elements from the TensorArray into output `value`. +// Outputs deterministic pseudorandom random values from a uniform distribution. // -// All elements selected by `indices` must have the same shape. +// The generated values follow a uniform distribution in the range `[0, 1)`. The +// lower bound 0 is included in the range, while the upper bound 1 is excluded. +// +// The outputs are a deterministic function of `shape` and `seed`. // // Arguments: -// handle: The handle to a TensorArray. -// indices: The locations in the TensorArray from which to read tensor elements. -// flow_in: A float scalar that enforces proper chaining of operations. -// dtype: The type of the elem that is returned. +// shape: The shape of the output tensor. +// seed: 2 seeds (shape [2]). // -// Returns All of the elements in the TensorArray, concatenated along a new -// axis (the new dimension 0). -func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV3Attr) (value tf.Output) { +// Returns Random values with specified shape. +func StatelessRandomUniform(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomUniformAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TensorArrayGatherV3", + Type: "StatelessRandomUniform", Input: []tf.Input{ - handle, indices, flow_in, + shape, seed, }, Attrs: attrs, } @@ -11838,48 +11505,48 @@ func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow return op.Output(0) } -// This op consumes a lock created by `MutexLock`. -// -// This op exists to consume a tensor created by `MutexLock` (other than -// direct control dependencies). It should be the only that consumes the tensor, -// and will raise an error if it is not. Its only purpose is to keep the -// mutex lock tensor alive until it is consumed by this op. -// -// **NOTE**: This operation must run on the same device as its input. This may -// be enforced via the `colocate_with` mechanism. +// Makes its input available to the next iteration. // // Arguments: -// mutex_lock: A tensor returned by `MutexLock`. +// data: The tensor to be made available to the next iteration. // -// Returns the created operation. -func ConsumeMutexLock(scope *Scope, mutex_lock tf.Output) (o *tf.Operation) { +// Returns The same tensor as `data`. +func NextIteration(scope *Scope, data tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "NextIteration", + Input: []tf.Input{ + data, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Output a fact about factorials. +func Fact(scope *Scope) (fact tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ConsumeMutexLock", - Input: []tf.Input{ - mutex_lock, - }, + Type: "Fact", } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Returns x / y element-wise for integer types. -// -// Truncation designates that negative numbers will round fractional quantities -// toward zero. I.e. -7 / 5 = -1. This matches C semantics but it is different -// than Python semantics. See `FloorDiv` for a division function that matches -// Python Semantics. +// Elementwise computes the bitwise XOR of `x` and `y`. // -// *NOTE*: `TruncateDiv` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func TruncateDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// The result will have those bits set, that are different in `x` and `y`. The +// computation is performed on the underlying representations of `x` and `y`. +func BitwiseXor(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TruncateDiv", + Type: "BitwiseXor", Input: []tf.Input{ x, y, }, @@ -11888,149 +11555,243 @@ func TruncateDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// Restores tensors from a V2 checkpoint. +// Deserialize `SparseTensor` objects. // -// For backward compatibility with the V1 format, this Op currently allows -// restoring from a V1 checkpoint as well: -// - This Op first attempts to find the V2 index file pointed to by "prefix", and -// if found proceed to read it as a V2 checkpoint; -// - Otherwise the V1 read path is invoked. -// Relying on this behavior is not recommended, as the ability to fall back to read -// V1 might be deprecated and eventually removed. +// The input `serialized_sparse` must have the shape `[?, ?, ..., ?, 3]` where +// the last dimension stores serialized `SparseTensor` objects and the other N +// dimensions (N >= 0) correspond to a batch. The ranks of the original +// `SparseTensor` objects must all match. When the final `SparseTensor` is +// created, its rank is the rank of the incoming `SparseTensor` objects plus N; +// the sparse tensors have been concatenated along new dimensions, one for each +// batch. // -// By default, restores the named tensors in full. If the caller wishes to restore -// specific slices of stored tensors, "shape_and_slices" should be non-empty -// strings and correspondingly well-formed. +// The output `SparseTensor` object's shape values for the original dimensions +// are the max across the input `SparseTensor` objects' shape values for the +// corresponding dimensions. The new dimensions match the size of the batch. // -// Callers must ensure all the named tensors are indeed stored in the checkpoint. +// The input `SparseTensor` objects' indices are assumed ordered in +// standard lexicographic order. If this is not the case, after this +// step run `SparseReorder` to restore index ordering. // -// Arguments: -// prefix: Must have a single element. The prefix of a V2 checkpoint. -// tensor_names: shape {N}. The names of the tensors to be restored. -// shape_and_slices: shape {N}. The slice specs of the tensors to be restored. -// Empty strings indicate that they are non-partitioned tensors. -// dtypes: shape {N}. The list of expected dtype for the tensors. Must match -// those stored in the checkpoint. +// For example, if the serialized input is a `[2 x 3]` matrix representing two +// original `SparseTensor` objects: // -// Returns shape {N}. The restored tensors, whose shapes are read from the -// checkpoint directly. -func RestoreV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and_slices tf.Output, dtypes []tf.DataType) (tensors []tf.Output) { +// index = [ 0] +// [10] +// [20] +// values = [1, 2, 3] +// shape = [50] +// +// and +// +// index = [ 2] +// [10] +// values = [4, 5] +// shape = [30] +// +// then the final deserialized `SparseTensor` will be: +// +// index = [0 0] +// [0 10] +// [0 20] +// [1 2] +// [1 10] +// values = [1, 2, 3, 4, 5] +// shape = [2 50] +// +// Arguments: +// serialized_sparse: The serialized `SparseTensor` objects. The last dimension +// must have 3 columns. +// dtype: The `dtype` of the serialized `SparseTensor` objects. +func DeserializeSparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataType) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{"dtype": dtype} opspec := tf.OpSpec{ - Type: "RestoreV2", + Type: "DeserializeSparse", Input: []tf.Input{ - prefix, tensor_names, shape_and_slices, + serialized_sparse, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if tensors, idx, err = makeOutputList(op, idx, "tensors"); err != nil { - scope.UpdateErr("RestoreV2", err) - return - } - return tensors + return op.Output(0), op.Output(1), op.Output(2) } -// Receives a tensor value broadcast from another device. -func CollectiveBcastRecv(scope *Scope, T tf.DataType, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"T": T, "group_size": group_size, "group_key": group_key, "instance_key": instance_key, "shape": shape} - opspec := tf.OpSpec{ - Type: "CollectiveBcastRecv", +// ResourceScatterNdUpdateAttr is an optional argument to ResourceScatterNdUpdate. +type ResourceScatterNdUpdateAttr func(optionalAttr) - Attrs: attrs, +// ResourceScatterNdUpdateUseLocking sets the optional use_locking attribute to value. +// +// value: An optional bool. Defaults to True. If True, the assignment will +// be protected by a lock; otherwise the behavior is undefined, +// but may exhibit less contention. +// If not specified, defaults to true +func ResourceScatterNdUpdateUseLocking(value bool) ResourceScatterNdUpdateAttr { + return func(m optionalAttr) { + m["use_locking"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Decode web-safe base64-encoded strings. +// Applies sparse `updates` to individual values or slices within a given // -// Input may or may not have padding at the end. See EncodeBase64 for padding. -// Web-safe means that input must use - and _ instead of + and /. +// variable according to `indices`. +// +// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. +// +// `indices` must be integer tensor, containing indices into `ref`. +// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. +// +// The innermost dimension of `indices` (with length `K`) corresponds to +// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +// dimension of `ref`. +// +// `updates` is `Tensor` of rank `Q-1+P-K` with shape: +// +// ``` +// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. +// ``` +// +// For example, say we want to update 4 scattered elements to a rank-1 tensor to +// 8 elements. In Python, that update would look like this: +// +// ```python +// ref = tfe.Variable([1, 2, 3, 4, 5, 6, 7, 8]) +// indices = tf.constant([[4], [3], [1] ,[7]]) +// updates = tf.constant([9, 10, 11, 12]) +// update = tf.scatter_nd_update(ref, indices, updates) +// with tf.Session() as sess: +// print sess.run(update) +// ``` +// +// The resulting update to ref would look like this: +// +// [1, 11, 3, 10, 9, 6, 7, 12] +// +// See @{tf.scatter_nd} for more details about how to make updates to +// slices. // // Arguments: -// input: Base64 strings to decode. +// ref: A resource handle. Must be from a VarHandleOp. +// indices: A Tensor. Must be one of the following types: int32, int64. +// A tensor of indices into ref. +// updates: A Tensor. Must have the same type as ref. A tensor of updated +// values to add to ref. // -// Returns Decoded strings. -func DecodeBase64(scope *Scope, input tf.Output) (output tf.Output) { +// Returns the created operation. +func ResourceScatterNdUpdate(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdUpdateAttr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "DecodeBase64", + Type: "ResourceScatterNdUpdate", Input: []tf.Input{ - input, + ref, indices, updates, }, + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Store the input tensor in the state of the current session. +// SqueezeAttr is an optional argument to Squeeze. +type SqueezeAttr func(optionalAttr) + +// SqueezeAxis sets the optional axis attribute to value. +// +// value: If specified, only squeezes the dimensions listed. The dimension +// index starts at 0. It is an error to squeeze a dimension that is not 1. Must +// be in the range `[-rank(input), rank(input))`. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func SqueezeAxis(value []int64) SqueezeAttr { + return func(m optionalAttr) { + m["squeeze_dims"] = value + } +} + +// Removes dimensions of size 1 from the shape of a tensor. +// +// Given a tensor `input`, this operation returns a tensor of the same type with +// all dimensions of size 1 removed. If you don't want to remove all size 1 +// dimensions, you can remove specific size 1 dimensions by specifying +// `axis`. +// +// For example: +// +// ``` +// # 't' is a tensor of shape [1, 2, 1, 3, 1, 1] +// shape(squeeze(t)) ==> [2, 3] +// ``` +// +// Or, to remove specific size 1 dimensions: +// +// ``` +// # 't' is a tensor of shape [1, 2, 1, 3, 1, 1] +// shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1] +// ``` // // Arguments: -// value: The tensor to be stored. +// input: The `input` to squeeze. // -// Returns The handle for the tensor stored in the session state, represented -// as a string. -func GetSessionHandle(scope *Scope, value tf.Output) (handle tf.Output) { +// Returns Contains the same data as `input`, but has one or more dimensions of +// size 1 removed. +func Squeeze(scope *Scope, input tf.Output, optional ...SqueezeAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "GetSessionHandle", + Type: "Squeeze", Input: []tf.Input{ - value, + input, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResourceSparseApplyProximalAdagradAttr is an optional argument to ResourceSparseApplyProximalAdagrad. -type ResourceSparseApplyProximalAdagradAttr func(optionalAttr) +// ResourceApplyAdadeltaAttr is an optional argument to ResourceApplyAdadelta. +type ResourceApplyAdadeltaAttr func(optionalAttr) -// ResourceSparseApplyProximalAdagradUseLocking sets the optional use_locking attribute to value. +// ResourceApplyAdadeltaUseLocking sets the optional use_locking attribute to value. // -// value: If True, updating of the var and accum tensors will be protected by +// value: If True, updating of the var, accum and update_accum tensors will be protected by // a lock; otherwise the behavior is undefined, but may exhibit less contention. // If not specified, defaults to false -func ResourceSparseApplyProximalAdagradUseLocking(value bool) ResourceSparseApplyProximalAdagradAttr { +func ResourceApplyAdadeltaUseLocking(value bool) ResourceApplyAdadeltaAttr { return func(m optionalAttr) { m["use_locking"] = value } } -// Sparse update entries in '*var' and '*accum' according to FOBOS algorithm. +// Update '*var' according to the adadelta scheme. // -// That is for rows we have grad for, we update var and accum as follows: -// accum += grad * grad -// prox_v = var -// prox_v -= lr * grad * (1 / sqrt(accum)) -// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} +// accum = rho() * accum + (1 - rho()) * grad.square(); +// update = (update_accum + epsilon).sqrt() * (accum + epsilon()).rsqrt() * grad; +// update_accum = rho() * update_accum + (1 - rho()) * update.square(); +// var -= update; // // Arguments: // var_: Should be from a Variable(). // accum: Should be from a Variable(). -// lr: Learning rate. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. +// accum_update: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// rho: Decay factor. Must be a scalar. +// epsilon: Constant factor. Must be a scalar. // grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. // // Returns the created operation. -func ResourceSparseApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyProximalAdagradAttr) (o *tf.Operation) { +func ResourceApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_update tf.Output, lr tf.Output, rho tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdadeltaAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -12039,55 +11800,69 @@ func ResourceSparseApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.O a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyProximalAdagrad", + Type: "ResourceApplyAdadelta", Input: []tf.Input{ - var_, accum, lr, l1, l2, grad, indices, + var_, accum, accum_update, lr, rho, epsilon, grad, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// MaxPool3DGradAttr is an optional argument to MaxPool3DGrad. -type MaxPool3DGradAttr func(optionalAttr) +// NonMaxSuppressionAttr is an optional argument to NonMaxSuppression. +type NonMaxSuppressionAttr func(optionalAttr) -// MaxPool3DGradDataFormat sets the optional data_format attribute to value. +// NonMaxSuppressionIouThreshold sets the optional iou_threshold attribute to value. // -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func MaxPool3DGradDataFormat(value string) MaxPool3DGradAttr { +// value: A float representing the threshold for deciding whether boxes +// overlap too much with respect to IOU. +// If not specified, defaults to 0.5 +func NonMaxSuppressionIouThreshold(value float32) NonMaxSuppressionAttr { return func(m optionalAttr) { - m["data_format"] = value + m["iou_threshold"] = value } } -// Computes gradients of max pooling function. +// Greedily selects a subset of bounding boxes in descending order of score, +// +// pruning away boxes that have high intersection-over-union (IOU) overlap +// with previously selected boxes. Bounding boxes are supplied as +// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any +// diagonal pair of box corners and the coordinates can be provided as normalized +// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm +// is agnostic to where the origin is in the coordinate system. Note that this +// algorithm is invariant to orthogonal transformations and translations +// of the coordinate system; thus translating or reflections of the coordinate +// system result in the same boxes being selected by the algorithm. +// The output of this operation is a set of integers indexing into the input +// collection of bounding boxes representing the selected boxes. The bounding +// box coordinates corresponding to the selected indices can then be obtained +// using the `tf.gather operation`. For example: +// selected_indices = tf.image.non_max_suppression( +// boxes, scores, max_output_size, iou_threshold) +// selected_boxes = tf.gather(boxes, selected_indices) // // Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. -func MaxPool3DGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DGradAttr) (output tf.Output) { +// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. +// scores: A 1-D float tensor of shape `[num_boxes]` representing a single +// score corresponding to each box (each row of boxes). +// max_output_size: A scalar integer tensor representing the maximum number of +// boxes to be selected by non max suppression. +// +// Returns A 1-D integer tensor of shape `[M]` representing the selected +// indices from the boxes tensor, where `M <= max_output_size`. +func NonMaxSuppression(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, optional ...NonMaxSuppressionAttr) (selected_indices tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MaxPool3DGrad", + Type: "NonMaxSuppression", Input: []tf.Input{ - orig_input, orig_output, grad, + boxes, scores, max_output_size, }, Attrs: attrs, } @@ -12095,92 +11870,174 @@ func MaxPool3DGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, gr return op.Output(0) } -// SparseReduceSumAttr is an optional argument to SparseReduceSum. -type SparseReduceSumAttr func(optionalAttr) +// Creates a dataset that emits `components` as a tuple of tensors once. +func TensorDataset(scope *Scope, components []tf.Output, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "TensorDataset", + Input: []tf.Input{ + tf.OutputList(components), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} -// SparseReduceSumKeepDims sets the optional keep_dims attribute to value. +// Component-wise multiplies a SparseTensor by a dense Tensor. // -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func SparseReduceSumKeepDims(value bool) SparseReduceSumAttr { - return func(m optionalAttr) { - m["keep_dims"] = value +// The output locations corresponding to the implicitly zero elements in the sparse +// tensor will be zero (i.e., will not take up storage space), regardless of the +// contents of the dense tensor (even if it's +/-INF and that INF*0 == NaN). +// +// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not +// the other direction. +// +// Arguments: +// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`. +// sp_shape: 1-D. Shape of the input SparseTensor. +// dense: `R`-D. The dense Tensor operand. +// +// Returns 1-D. The `N` values that are operated on. +func SparseDenseCwiseMul(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseDenseCwiseMul", + Input: []tf.Input{ + sp_indices, sp_values, sp_shape, dense, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Computes the sum of elements across dimensions of a SparseTensor. +// 2D real-valued fast Fourier transform. // -// This Op takes a SparseTensor and is the sparse counterpart to -// `tf.reduce_sum()`. In particular, this Op also returns a dense `Tensor` -// instead of a sparse one. +// Computes the 2-dimensional discrete Fourier transform of a real-valued signal +// over the inner-most 2 dimensions of `input`. // -// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained -// with length 1. +// Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the +// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension +// of `output`: the zero-frequency term, followed by the `fft_length / 2` +// positive-frequency terms. // -// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor -// with a single element is returned. Additionally, the axes can be negative, -// which are interpreted according to the indexing rules in Python. +// Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. // // Arguments: -// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. -// input_shape: 1-D. Shape of the input SparseTensor. -// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. +// input: A float32 tensor. +// fft_length: An int32 tensor of shape [2]. The FFT length for each dimension. // -// Returns `R-K`-D. The reduced Tensor. -func SparseReduceSum(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceSumAttr) (output tf.Output) { +// Returns A complex64 tensor of the same rank as `input`. The inner-most 2 +// dimensions of `input` are replaced with their 2D Fourier transform. The +// inner-most dimension contains `fft_length / 2 + 1` unique frequency +// components. +// +// @compatibility(numpy) +// Equivalent to np.fft.rfft2 +// @end_compatibility +func RFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "SparseReduceSum", + Type: "RFFT2D", Input: []tf.Input{ - input_indices, input_values, input_shape, reduction_axes, + input, fft_length, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// VariableShapeAttr is an optional argument to VariableShape. -type VariableShapeAttr func(optionalAttr) +// Pads a tensor with zeros. +// +// This operation pads a `input` with zeros according to the `paddings` you +// specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the +// rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates +// how many zeros to add before the contents of `input` in that dimension, and +// `paddings[D, 1]` indicates how many zeros to add after the contents of `input` +// in that dimension. +// +// The padded size of each dimension D of the output is: +// +// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` +// +// For example: +// +// ``` +// # 't' is [[1, 1], [2, 2]] +// # 'paddings' is [[1, 1], [2, 2]] +// # rank of 't' is 2 +// pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] +// [0, 0, 1, 1, 0, 0] +// [0, 0, 2, 2, 0, 0] +// [0, 0, 0, 0, 0, 0]] +// ``` +func Pad(scope *Scope, input tf.Output, paddings tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Pad", + Input: []tf.Input{ + input, paddings, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} -// VariableShapeOutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_INT32 -func VariableShapeOutType(value tf.DataType) VariableShapeAttr { - return func(m optionalAttr) { - m["out_type"] = value +// Checks whether a resource handle-based variable has been initialized. +// +// Arguments: +// resource: the input resource handle. +// +// Returns a scalar boolean which is true if the variable has been +// initialized. +func VarIsInitializedOp(scope *Scope, resource tf.Output) (is_initialized tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "VarIsInitializedOp", + Input: []tf.Input{ + resource, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Returns the shape of the variable pointed to by `resource`. +// Converts each string in the input Tensor to its hash mod by a number of buckets. // -// This operation returns a 1-D integer tensor representing the shape of `input`. +// The hash function is deterministic on the content of the string within the +// process and will never change. However, it is not suitable for cryptography. +// This function may be used when CPU time is scarce and inputs are trusted or +// unimportant. There is a risk of adversaries constructing inputs that all hash +// to the same bucket. To prevent this problem, use a strong hash function with +// `tf.string_to_hash_bucket_strong`. // -// For example: +// Arguments: +// input: The strings to assign a hash bucket. +// num_buckets: The number of buckets. // -// ``` -// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] -// shape(t) ==> [2, 2, 3] -// ``` -func VariableShape(scope *Scope, input tf.Output, optional ...VariableShapeAttr) (output tf.Output) { +// Returns A Tensor of the same shape as the input `string_tensor`. +func StringToHashBucketFast(scope *Scope, input tf.Output, num_buckets int64) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"num_buckets": num_buckets} opspec := tf.OpSpec{ - Type: "VariableShape", + Type: "StringToHashBucketFast", Input: []tf.Input{ input, }, @@ -12190,241 +12047,245 @@ func VariableShape(scope *Scope, input tf.Output, optional ...VariableShapeAttr) return op.Output(0) } -// SparseToSparseSetOperationAttr is an optional argument to SparseToSparseSetOperation. -type SparseToSparseSetOperationAttr func(optionalAttr) +// TensorArrayGatherV3Attr is an optional argument to TensorArrayGatherV3. +type TensorArrayGatherV3Attr func(optionalAttr) -// SparseToSparseSetOperationValidateIndices sets the optional validate_indices attribute to value. -// If not specified, defaults to true -func SparseToSparseSetOperationValidateIndices(value bool) SparseToSparseSetOperationAttr { +// TensorArrayGatherV3ElementShape sets the optional element_shape attribute to value. +// +// value: The expected shape of an element, if known. Used to +// validate the shapes of TensorArray elements. If this shape is not +// fully specified, gathering zero-size TensorArrays is an error. +// If not specified, defaults to +func TensorArrayGatherV3ElementShape(value tf.Shape) TensorArrayGatherV3Attr { return func(m optionalAttr) { - m["validate_indices"] = value + m["element_shape"] = value } } -// Applies set operation along last dimension of 2 `SparseTensor` inputs. -// -// See SetOperationOp::SetOperationFromContext for values of `set_operation`. -// -// If `validate_indices` is `True`, `SparseToSparseSetOperation` validates the -// order and range of `set1` and `set2` indices. -// -// Input `set1` is a `SparseTensor` represented by `set1_indices`, `set1_values`, -// and `set1_shape`. For `set1` ranked `n`, 1st `n-1` dimensions must be the same -// as `set2`. Dimension `n` contains values in a set, duplicates are allowed but -// ignored. -// -// Input `set2` is a `SparseTensor` represented by `set2_indices`, `set2_values`, -// and `set2_shape`. For `set2` ranked `n`, 1st `n-1` dimensions must be the same -// as `set1`. Dimension `n` contains values in a set, duplicates are allowed but -// ignored. -// -// If `validate_indices` is `True`, this op validates the order and range of `set1` -// and `set2` indices. +// Gather specific elements from the TensorArray into output `value`. // -// Output `result` is a `SparseTensor` represented by `result_indices`, -// `result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this -// has rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth` -// dimension contains the result of `set_operation` applied to the corresponding -// `[0...n-1]` dimension of `set`. +// All elements selected by `indices` must have the same shape. // // Arguments: -// set1_indices: 2D `Tensor`, indices of a `SparseTensor`. Must be in row-major -// order. -// set1_values: 1D `Tensor`, values of a `SparseTensor`. Must be in row-major -// order. -// set1_shape: 1D `Tensor`, shape of a `SparseTensor`. `set1_shape[0...n-1]` must -// be the same as `set2_shape[0...n-1]`, `set1_shape[n]` is the -// max set size across `0...n-1` dimensions. -// set2_indices: 2D `Tensor`, indices of a `SparseTensor`. Must be in row-major -// order. -// set2_values: 1D `Tensor`, values of a `SparseTensor`. Must be in row-major -// order. -// set2_shape: 1D `Tensor`, shape of a `SparseTensor`. `set2_shape[0...n-1]` must -// be the same as `set1_shape[0...n-1]`, `set2_shape[n]` is the -// max set size across `0...n-1` dimensions. -// +// handle: The handle to a TensorArray. +// indices: The locations in the TensorArray from which to read tensor elements. +// flow_in: A float scalar that enforces proper chaining of operations. +// dtype: The type of the elem that is returned. // -// Returns 2D indices of a `SparseTensor`.1D values of a `SparseTensor`.1D `Tensor` shape of a `SparseTensor`. `result_shape[0...n-1]` is -// the same as the 1st `n-1` dimensions of `set1` and `set2`, `result_shape[n]` -// is the max result set size across all `0...n-1` dimensions. -func SparseToSparseSetOperation(scope *Scope, set1_indices tf.Output, set1_values tf.Output, set1_shape tf.Output, set2_indices tf.Output, set2_values tf.Output, set2_shape tf.Output, set_operation string, optional ...SparseToSparseSetOperationAttr) (result_indices tf.Output, result_values tf.Output, result_shape tf.Output) { +// Returns All of the elements in the TensorArray, concatenated along a new +// axis (the new dimension 0). +func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV3Attr) (value tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"set_operation": set_operation} + attrs := map[string]interface{}{"dtype": dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "SparseToSparseSetOperation", + Type: "TensorArrayGatherV3", Input: []tf.Input{ - set1_indices, set1_values, set1_shape, set2_indices, set2_values, set2_shape, + handle, indices, flow_in, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Computes softmax cross entropy cost and gradients to backpropagate. +// This op consumes a lock created by `MutexLock`. // -// Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept -// a matrix of label probabilities, but rather a single label per row -// of features. This label is considered to have probability 1.0 for the -// given row. +// This op exists to consume a tensor created by `MutexLock` (other than +// direct control dependencies). It should be the only that consumes the tensor, +// and will raise an error if it is not. Its only purpose is to keep the +// mutex lock tensor alive until it is consumed by this op. // -// Inputs are the logits, not probabilities. +// **NOTE**: This operation must run on the same device as its input. This may +// be enforced via the `colocate_with` mechanism. // // Arguments: -// features: batch_size x num_classes matrix -// labels: batch_size vector with values in [0, num_classes). -// This is the label for the given minibatch entry. +// mutex_lock: A tensor returned by `MutexLock`. // -// Returns Per example loss (batch_size vector).backpropagated gradients (batch_size x num_classes matrix). -func SparseSoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.Output) (loss tf.Output, backprop tf.Output) { +// Returns the created operation. +func ConsumeMutexLock(scope *Scope, mutex_lock tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseSoftmaxCrossEntropyWithLogits", + Type: "ConsumeMutexLock", Input: []tf.Input{ - features, labels, + mutex_lock, }, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return scope.AddOperation(opspec) } -// Fast Fourier transform. -// -// Computes the 1-dimensional discrete Fourier transform over the inner-most -// dimension of `input`. -// -// Arguments: -// input: A complex64 tensor. +// Returns x / y element-wise for integer types. // -// Returns A complex64 tensor of the same shape as `input`. The inner-most -// dimension of `input` is replaced with its 1D Fourier transform. +// Truncation designates that negative numbers will round fractional quantities +// toward zero. I.e. -7 / 5 = -1. This matches C semantics but it is different +// than Python semantics. See `FloorDiv` for a division function that matches +// Python Semantics. // -// @compatibility(numpy) -// Equivalent to np.fft.fft -// @end_compatibility -func FFT(scope *Scope, input tf.Output) (output tf.Output) { +// *NOTE*: `TruncateDiv` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func TruncateDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "FFT", + Type: "TruncateDiv", Input: []tf.Input{ - input, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Transforms a serialized tensorflow.TensorProto proto into a Tensor. +// Restores tensors from a V2 checkpoint. +// +// For backward compatibility with the V1 format, this Op currently allows +// restoring from a V1 checkpoint as well: +// - This Op first attempts to find the V2 index file pointed to by "prefix", and +// if found proceed to read it as a V2 checkpoint; +// - Otherwise the V1 read path is invoked. +// Relying on this behavior is not recommended, as the ability to fall back to read +// V1 might be deprecated and eventually removed. +// +// By default, restores the named tensors in full. If the caller wishes to restore +// specific slices of stored tensors, "shape_and_slices" should be non-empty +// strings and correspondingly well-formed. +// +// Callers must ensure all the named tensors are indeed stored in the checkpoint. // // Arguments: -// serialized: A scalar string containing a serialized TensorProto proto. -// out_type: The type of the serialized tensor. The provided type must match the -// type of the serialized tensor and no implicit conversion will take place. +// prefix: Must have a single element. The prefix of a V2 checkpoint. +// tensor_names: shape {N}. The names of the tensors to be restored. +// shape_and_slices: shape {N}. The slice specs of the tensors to be restored. +// Empty strings indicate that they are non-partitioned tensors. +// dtypes: shape {N}. The list of expected dtype for the tensors. Must match +// those stored in the checkpoint. // -// Returns A Tensor of type `out_type`. -func ParseTensor(scope *Scope, serialized tf.Output, out_type tf.DataType) (output tf.Output) { +// Returns shape {N}. The restored tensors, whose shapes are read from the +// checkpoint directly. +func RestoreV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and_slices tf.Output, dtypes []tf.DataType) (tensors []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"out_type": out_type} + attrs := map[string]interface{}{"dtypes": dtypes} opspec := tf.OpSpec{ - Type: "ParseTensor", + Type: "RestoreV2", Input: []tf.Input{ - serialized, + prefix, tensor_names, shape_and_slices, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if tensors, idx, err = makeOutputList(op, idx, "tensors"); err != nil { + scope.UpdateErr("RestoreV2", err) + return + } + return tensors } -// MaxPoolWithArgmaxAttr is an optional argument to MaxPoolWithArgmax. -type MaxPoolWithArgmaxAttr func(optionalAttr) +// Receives a tensor value broadcast from another device. +func CollectiveBcastRecv(scope *Scope, T tf.DataType, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"T": T, "group_size": group_size, "group_key": group_key, "instance_key": instance_key, "shape": shape} + opspec := tf.OpSpec{ + Type: "CollectiveBcastRecv", -// MaxPoolWithArgmaxTargmax sets the optional Targmax attribute to value. -// If not specified, defaults to DT_INT64 -func MaxPoolWithArgmaxTargmax(value tf.DataType) MaxPoolWithArgmaxAttr { - return func(m optionalAttr) { - m["Targmax"] = value + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Performs max pooling on the input and outputs both max values and indices. +// Decode web-safe base64-encoded strings. // -// The indices in `argmax` are flattened, so that a maximum value at position -// `[b, y, x, c]` becomes flattened index -// `((b * height + y) * width + x) * channels + c`. +// Input may or may not have padding at the end. See EncodeBase64 for padding. +// Web-safe means that input must use - and _ instead of + and /. // -// The indices returned are always in `[0, height) x [0, width)` before flattening, -// even if padding is involved and the mathematically correct answer is outside -// (either negative or too large). This is a bug, but fixing it is difficult to do -// in a safe backwards compatible way, especially due to flattening. +// Arguments: +// input: Base64 strings to decode. +// +// Returns Decoded strings. +func DecodeBase64(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DecodeBase64", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Store the input tensor in the state of the current session. // // Arguments: -// input: 4-D with shape `[batch, height, width, channels]`. Input to pool over. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. +// value: The tensor to be stored. // -// Returns The max pooled output tensor.4-D. The flattened indices of the max values chosen for each output. -func MaxPoolWithArgmax(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolWithArgmaxAttr) (output tf.Output, argmax tf.Output) { +// Returns The handle for the tensor stored in the session state, represented +// as a string. +func GetSessionHandle(scope *Scope, value tf.Output) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "MaxPoolWithArgmax", + Type: "GetSessionHandle", Input: []tf.Input{ - input, + value, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// ResourceSparseApplyAdagradDAAttr is an optional argument to ResourceSparseApplyAdagradDA. -type ResourceSparseApplyAdagradDAAttr func(optionalAttr) +// ResourceSparseApplyProximalAdagradAttr is an optional argument to ResourceSparseApplyProximalAdagrad. +type ResourceSparseApplyProximalAdagradAttr func(optionalAttr) -// ResourceSparseApplyAdagradDAUseLocking sets the optional use_locking attribute to value. +// ResourceSparseApplyProximalAdagradUseLocking sets the optional use_locking attribute to value. // // value: If True, updating of the var and accum tensors will be protected by // a lock; otherwise the behavior is undefined, but may exhibit less contention. // If not specified, defaults to false -func ResourceSparseApplyAdagradDAUseLocking(value bool) ResourceSparseApplyAdagradDAAttr { +func ResourceSparseApplyProximalAdagradUseLocking(value bool) ResourceSparseApplyProximalAdagradAttr { return func(m optionalAttr) { m["use_locking"] = value } } -// Update entries in '*var' and '*accum' according to the proximal adagrad scheme. +// Sparse update entries in '*var' and '*accum' according to FOBOS algorithm. +// +// That is for rows we have grad for, we update var and accum as follows: +// accum += grad * grad +// prox_v = var +// prox_v -= lr * grad * (1 / sqrt(accum)) +// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} // // Arguments: // var_: Should be from a Variable(). -// gradient_accumulator: Should be from a Variable(). -// gradient_squared_accumulator: Should be from a Variable(). -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. +// accum: Should be from a Variable(). // lr: Learning rate. Must be a scalar. // l1: L1 regularization. Must be a scalar. // l2: L2 regularization. Must be a scalar. -// global_step: Training step number. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. // // Returns the created operation. -func ResourceSparseApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator tf.Output, gradient_squared_accumulator tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, global_step tf.Output, optional ...ResourceSparseApplyAdagradDAAttr) (o *tf.Operation) { +func ResourceSparseApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyProximalAdagradAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -12433,209 +12294,361 @@ func ResourceSparseApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumul a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyAdagradDA", + Type: "ResourceSparseApplyProximalAdagrad", Input: []tf.Input{ - var_, gradient_accumulator, gradient_squared_accumulator, grad, indices, lr, l1, l2, global_step, + var_, accum, lr, l1, l2, grad, indices, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// EncodeJpegAttr is an optional argument to EncodeJpeg. -type EncodeJpegAttr func(optionalAttr) +// MaxPool3DGradAttr is an optional argument to MaxPool3DGrad. +type MaxPool3DGradAttr func(optionalAttr) -// EncodeJpegFormat sets the optional format attribute to value. +// MaxPool3DGradDataFormat sets the optional data_format attribute to value. // -// value: Per pixel image format. -// If not specified, defaults to "" -func EncodeJpegFormat(value string) EncodeJpegAttr { +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func MaxPool3DGradDataFormat(value string) MaxPool3DGradAttr { return func(m optionalAttr) { - m["format"] = value + m["data_format"] = value } } -// EncodeJpegQuality sets the optional quality attribute to value. +// Computes gradients of max pooling function. // -// value: Quality of the compression from 0 to 100 (higher is better and slower). -// If not specified, defaults to 95 -func EncodeJpegQuality(value int64) EncodeJpegAttr { - return func(m optionalAttr) { - m["quality"] = value +// Arguments: +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +func MaxPool3DGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MaxPool3DGrad", + Input: []tf.Input{ + orig_input, orig_output, grad, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// EncodeJpegProgressive sets the optional progressive attribute to value. +// SparseReduceSumAttr is an optional argument to SparseReduceSum. +type SparseReduceSumAttr func(optionalAttr) + +// SparseReduceSumKeepDims sets the optional keep_dims attribute to value. // -// value: If True, create a JPEG that loads progressively (coarse to fine). +// value: If true, retain reduced dimensions with length 1. // If not specified, defaults to false -func EncodeJpegProgressive(value bool) EncodeJpegAttr { +func SparseReduceSumKeepDims(value bool) SparseReduceSumAttr { return func(m optionalAttr) { - m["progressive"] = value + m["keep_dims"] = value } } -// EncodeJpegOptimizeSize sets the optional optimize_size attribute to value. +// Computes the sum of elements across dimensions of a SparseTensor. // -// value: If True, spend CPU/RAM to reduce size with no quality change. -// If not specified, defaults to false -func EncodeJpegOptimizeSize(value bool) EncodeJpegAttr { - return func(m optionalAttr) { - m["optimize_size"] = value +// This Op takes a SparseTensor and is the sparse counterpart to +// `tf.reduce_sum()`. In particular, this Op also returns a dense `Tensor` +// instead of a sparse one. +// +// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained +// with length 1. +// +// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor +// with a single element is returned. Additionally, the axes can be negative, +// which are interpreted according to the indexing rules in Python. +// +// Arguments: +// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. +// input_shape: 1-D. Shape of the input SparseTensor. +// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. +// +// Returns `R-K`-D. The reduced Tensor. +func SparseReduceSum(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceSumAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SparseReduceSum", + Input: []tf.Input{ + input_indices, input_values, input_shape, reduction_axes, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// EncodeJpegChromaDownsampling sets the optional chroma_downsampling attribute to value. -// -// value: See http://en.wikipedia.org/wiki/Chroma_subsampling. -// If not specified, defaults to true -func EncodeJpegChromaDownsampling(value bool) EncodeJpegAttr { +// VariableShapeAttr is an optional argument to VariableShape. +type VariableShapeAttr func(optionalAttr) + +// VariableShapeOutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_INT32 +func VariableShapeOutType(value tf.DataType) VariableShapeAttr { return func(m optionalAttr) { - m["chroma_downsampling"] = value + m["out_type"] = value } } -// EncodeJpegDensityUnit sets the optional density_unit attribute to value. +// Returns the shape of the variable pointed to by `resource`. // -// value: Unit used to specify `x_density` and `y_density`: -// pixels per inch (`'in'`) or centimeter (`'cm'`). -// If not specified, defaults to "in" -func EncodeJpegDensityUnit(value string) EncodeJpegAttr { - return func(m optionalAttr) { - m["density_unit"] = value +// This operation returns a 1-D integer tensor representing the shape of `input`. +// +// For example: +// +// ``` +// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] +// shape(t) ==> [2, 2, 3] +// ``` +func VariableShape(scope *Scope, input tf.Output, optional ...VariableShapeAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "VariableShape", + Input: []tf.Input{ + input, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// EncodeJpegXDensity sets the optional x_density attribute to value. -// -// value: Horizontal pixels per density unit. -// If not specified, defaults to 300 -func EncodeJpegXDensity(value int64) EncodeJpegAttr { +// SparseToSparseSetOperationAttr is an optional argument to SparseToSparseSetOperation. +type SparseToSparseSetOperationAttr func(optionalAttr) + +// SparseToSparseSetOperationValidateIndices sets the optional validate_indices attribute to value. +// If not specified, defaults to true +func SparseToSparseSetOperationValidateIndices(value bool) SparseToSparseSetOperationAttr { return func(m optionalAttr) { - m["x_density"] = value + m["validate_indices"] = value } } -// EncodeJpegYDensity sets the optional y_density attribute to value. +// Applies set operation along last dimension of 2 `SparseTensor` inputs. // -// value: Vertical pixels per density unit. -// If not specified, defaults to 300 -func EncodeJpegYDensity(value int64) EncodeJpegAttr { - return func(m optionalAttr) { - m["y_density"] = value +// See SetOperationOp::SetOperationFromContext for values of `set_operation`. +// +// If `validate_indices` is `True`, `SparseToSparseSetOperation` validates the +// order and range of `set1` and `set2` indices. +// +// Input `set1` is a `SparseTensor` represented by `set1_indices`, `set1_values`, +// and `set1_shape`. For `set1` ranked `n`, 1st `n-1` dimensions must be the same +// as `set2`. Dimension `n` contains values in a set, duplicates are allowed but +// ignored. +// +// Input `set2` is a `SparseTensor` represented by `set2_indices`, `set2_values`, +// and `set2_shape`. For `set2` ranked `n`, 1st `n-1` dimensions must be the same +// as `set1`. Dimension `n` contains values in a set, duplicates are allowed but +// ignored. +// +// If `validate_indices` is `True`, this op validates the order and range of `set1` +// and `set2` indices. +// +// Output `result` is a `SparseTensor` represented by `result_indices`, +// `result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this +// has rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth` +// dimension contains the result of `set_operation` applied to the corresponding +// `[0...n-1]` dimension of `set`. +// +// Arguments: +// set1_indices: 2D `Tensor`, indices of a `SparseTensor`. Must be in row-major +// order. +// set1_values: 1D `Tensor`, values of a `SparseTensor`. Must be in row-major +// order. +// set1_shape: 1D `Tensor`, shape of a `SparseTensor`. `set1_shape[0...n-1]` must +// be the same as `set2_shape[0...n-1]`, `set1_shape[n]` is the +// max set size across `0...n-1` dimensions. +// set2_indices: 2D `Tensor`, indices of a `SparseTensor`. Must be in row-major +// order. +// set2_values: 1D `Tensor`, values of a `SparseTensor`. Must be in row-major +// order. +// set2_shape: 1D `Tensor`, shape of a `SparseTensor`. `set2_shape[0...n-1]` must +// be the same as `set1_shape[0...n-1]`, `set2_shape[n]` is the +// max set size across `0...n-1` dimensions. +// +// +// Returns 2D indices of a `SparseTensor`.1D values of a `SparseTensor`.1D `Tensor` shape of a `SparseTensor`. `result_shape[0...n-1]` is +// the same as the 1st `n-1` dimensions of `set1` and `set2`, `result_shape[n]` +// is the max result set size across all `0...n-1` dimensions. +func SparseToSparseSetOperation(scope *Scope, set1_indices tf.Output, set1_values tf.Output, set1_shape tf.Output, set2_indices tf.Output, set2_values tf.Output, set2_shape tf.Output, set_operation string, optional ...SparseToSparseSetOperationAttr) (result_indices tf.Output, result_values tf.Output, result_shape tf.Output) { + if scope.Err() != nil { + return } -} - -// EncodeJpegXmpMetadata sets the optional xmp_metadata attribute to value. -// -// value: If not empty, embed this XMP metadata in the image header. -// If not specified, defaults to "" -func EncodeJpegXmpMetadata(value string) EncodeJpegAttr { - return func(m optionalAttr) { - m["xmp_metadata"] = value + attrs := map[string]interface{}{"set_operation": set_operation} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SparseToSparseSetOperation", + Input: []tf.Input{ + set1_indices, set1_values, set1_shape, set2_indices, set2_values, set2_shape, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } -// JPEG-encode an image. +// Computes softmax cross entropy cost and gradients to backpropagate. // -// `image` is a 3-D uint8 Tensor of shape `[height, width, channels]`. +// Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept +// a matrix of label probabilities, but rather a single label per row +// of features. This label is considered to have probability 1.0 for the +// given row. // -// The attr `format` can be used to override the color format of the encoded -// output. Values can be: +// Inputs are the logits, not probabilities. // -// * `''`: Use a default format based on the number of channels in the image. -// * `grayscale`: Output a grayscale JPEG image. The `channels` dimension -// of `image` must be 1. -// * `rgb`: Output an RGB JPEG image. The `channels` dimension -// of `image` must be 3. +// Arguments: +// features: batch_size x num_classes matrix +// labels: batch_size vector with values in [0, num_classes). +// This is the label for the given minibatch entry. // -// If `format` is not specified or is the empty string, a default format is picked -// in function of the number of channels in `image`: +// Returns Per example loss (batch_size vector).backpropagated gradients (batch_size x num_classes matrix). +func SparseSoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.Output) (loss tf.Output, backprop tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSoftmaxCrossEntropyWithLogits", + Input: []tf.Input{ + features, labels, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Fast Fourier transform. // -// * 1: Output a grayscale image. -// * 3: Output an RGB image. +// Computes the 1-dimensional discrete Fourier transform over the inner-most +// dimension of `input`. // // Arguments: -// image: 3-D with shape `[height, width, channels]`. +// input: A complex64 tensor. // -// Returns 0-D. JPEG-encoded image. -func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (contents tf.Output) { +// Returns A complex64 tensor of the same shape as `input`. The inner-most +// dimension of `input` is replaced with its 1D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.fft +// @end_compatibility +func FFT(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "EncodeJpeg", + Type: "FFT", Input: []tf.Input{ - image, + input, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// MultinomialAttr is an optional argument to Multinomial. -type MultinomialAttr func(optionalAttr) - -// MultinomialSeed sets the optional seed attribute to value. +// Transforms a serialized tensorflow.TensorProto proto into a Tensor. // -// value: If either seed or seed2 is set to be non-zero, the internal random number -// generator is seeded by the given seed. Otherwise, a random seed is used. -// If not specified, defaults to 0 -func MultinomialSeed(value int64) MultinomialAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// MultinomialSeed2 sets the optional seed2 attribute to value. +// Arguments: +// serialized: A scalar string containing a serialized TensorProto proto. +// out_type: The type of the serialized tensor. The provided type must match the +// type of the serialized tensor and no implicit conversion will take place. // -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func MultinomialSeed2(value int64) MultinomialAttr { - return func(m optionalAttr) { - m["seed2"] = value +// Returns A Tensor of type `out_type`. +func ParseTensor(scope *Scope, serialized tf.Output, out_type tf.DataType) (output tf.Output) { + if scope.Err() != nil { + return } + attrs := map[string]interface{}{"out_type": out_type} + opspec := tf.OpSpec{ + Type: "ParseTensor", + Input: []tf.Input{ + serialized, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) } -// MultinomialOutputDtype sets the optional output_dtype attribute to value. +// MaxPoolWithArgmaxAttr is an optional argument to MaxPoolWithArgmax. +type MaxPoolWithArgmaxAttr func(optionalAttr) + +// MaxPoolWithArgmaxTargmax sets the optional Targmax attribute to value. // If not specified, defaults to DT_INT64 -func MultinomialOutputDtype(value tf.DataType) MultinomialAttr { +func MaxPoolWithArgmaxTargmax(value tf.DataType) MaxPoolWithArgmaxAttr { return func(m optionalAttr) { - m["output_dtype"] = value + m["Targmax"] = value } } -// Draws samples from a multinomial distribution. +// Performs max pooling on the input and outputs both max values and indices. +// +// The indices in `argmax` are flattened, so that a maximum value at position +// `[b, y, x, c]` becomes flattened index +// `((b * height + y) * width + x) * channels + c`. +// +// The indices returned are always in `[0, height) x [0, width)` before flattening, +// even if padding is involved and the mathematically correct answer is outside +// (either negative or too large). This is a bug, but fixing it is difficult to do +// in a safe backwards compatible way, especially due to flattening. // // Arguments: -// logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]` -// represents the unnormalized log probabilities for all classes. -// num_samples: 0-D. Number of independent samples to draw for each row slice. +// input: 4-D with shape `[batch, height, width, channels]`. Input to pool over. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. // -// Returns 2-D Tensor with shape `[batch_size, num_samples]`. Each slice `[i, :]` -// contains the drawn class labels with range `[0, num_classes)`. -func Multinomial(scope *Scope, logits tf.Output, num_samples tf.Output, optional ...MultinomialAttr) (output tf.Output) { +// Returns The max pooled output tensor.4-D. The flattened indices of the max values chosen for each output. +func MaxPoolWithArgmax(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolWithArgmaxAttr) (output tf.Output, argmax tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Multinomial", + Type: "MaxPoolWithArgmax", Input: []tf.Input{ - logits, num_samples, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } // Returns the truth value of NOT x element-wise. @@ -13157,62 +13170,6 @@ func ResourceScatterSub(scope *Scope, resource tf.Output, indices tf.Output, upd return scope.AddOperation(opspec) } -// Inverse 2D fast Fourier transform. -// -// Computes the inverse 2-dimensional discrete Fourier transform over the -// inner-most 2 dimensions of `input`. -// -// Arguments: -// input: A complex64 tensor. -// -// Returns A complex64 tensor of the same shape as `input`. The inner-most 2 -// dimensions of `input` are replaced with their inverse 2D Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.fft.ifft2 -// @end_compatibility -func IFFT2D(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "IFFT2D", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// 2D fast Fourier transform. -// -// Computes the 2-dimensional discrete Fourier transform over the inner-most -// 2 dimensions of `input`. -// -// Arguments: -// input: A complex64 tensor. -// -// Returns A complex64 tensor of the same shape as `input`. The inner-most 2 -// dimensions of `input` are replaced with their 2D Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.fft.fft2 -// @end_compatibility -func FFT2D(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "FFT2D", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // ResourceApplyProximalGradientDescentAttr is an optional argument to ResourceApplyProximalGradientDescent. type ResourceApplyProximalGradientDescentAttr func(optionalAttr) @@ -15299,51 +15256,26 @@ func BoostedTreesEnsembleResourceHandleOpContainer(value string) BoostedTreesEns } // BoostedTreesEnsembleResourceHandleOpSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func BoostedTreesEnsembleResourceHandleOpSharedName(value string) BoostedTreesEnsembleResourceHandleOpAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Creates a handle to a BoostedTreesEnsembleResource -func BoostedTreesEnsembleResourceHandleOp(scope *Scope, optional ...BoostedTreesEnsembleResourceHandleOpAttr) (resource tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "BoostedTreesEnsembleResourceHandleOp", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Concatenates tensors along one dimension. -// -// Arguments: -// concat_dim: 0-D. The dimension along which to concatenate. Must be in the -// range [0, rank(values)). -// values: The `N` Tensors to concatenate. Their ranks and types must match, -// and their sizes must match in all dimensions except `concat_dim`. -// -// Returns A `Tensor` with the concatenation of values stacked along the -// `concat_dim` dimension. This tensor's shape matches that of `values` except -// in `concat_dim` where it has the sum of the sizes. -func Concat(scope *Scope, concat_dim tf.Output, values []tf.Output) (output tf.Output) { +// If not specified, defaults to "" +func BoostedTreesEnsembleResourceHandleOpSharedName(value string) BoostedTreesEnsembleResourceHandleOpAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Creates a handle to a BoostedTreesEnsembleResource +func BoostedTreesEnsembleResourceHandleOp(scope *Scope, optional ...BoostedTreesEnsembleResourceHandleOpAttr) (resource tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Concat", - Input: []tf.Input{ - concat_dim, tf.OutputList(values), - }, + Type: "BoostedTreesEnsembleResourceHandleOp", + + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) @@ -16267,6 +16199,62 @@ func MutableDenseHashTableV2(scope *Scope, empty_key tf.Output, value_dtype tf.D return op.Output(0) } +// 2D fast Fourier transform. +// +// Computes the 2-dimensional discrete Fourier transform over the inner-most +// 2 dimensions of `input`. +// +// Arguments: +// input: A complex64 tensor. +// +// Returns A complex64 tensor of the same shape as `input`. The inner-most 2 +// dimensions of `input` are replaced with their 2D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.fft2 +// @end_compatibility +func FFT2D(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "FFT2D", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Inverse 2D fast Fourier transform. +// +// Computes the inverse 2-dimensional discrete Fourier transform over the +// inner-most 2 dimensions of `input`. +// +// Arguments: +// input: A complex64 tensor. +// +// Returns A complex64 tensor of the same shape as `input`. The inner-most 2 +// dimensions of `input` are replaced with their inverse 2D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.ifft2 +// @end_compatibility +func IFFT2D(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IFFT2D", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // ResourceApplyRMSPropAttr is an optional argument to ResourceApplyRMSProp. type ResourceApplyRMSPropAttr func(optionalAttr) @@ -17777,77 +17765,6 @@ func SparseCross(scope *Scope, indices []tf.Output, values []tf.Output, shapes [ return op.Output(0), op.Output(1), op.Output(2) } -// Concatenates quantized tensors along one dimension. -// -// Arguments: -// concat_dim: 0-D. The dimension along which to concatenate. Must be in the -// range [0, rank(values)). -// values: The `N` Tensors to concatenate. Their ranks and types must match, -// and their sizes must match in all dimensions except `concat_dim`. -// input_mins: The minimum scalar values for each of the input tensors. -// input_maxes: The maximum scalar values for each of the input tensors. -// -// Returns A `Tensor` with the concatenation of values stacked along the -// `concat_dim` dimension. This tensor's shape matches that of `values` except -// in `concat_dim` where it has the sum of the sizes.The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents. -func QuantizedConcat(scope *Scope, concat_dim tf.Output, values []tf.Output, input_mins []tf.Output, input_maxes []tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "QuantizedConcat", - Input: []tf.Input{ - concat_dim, tf.OutputList(values), tf.OutputList(input_mins), tf.OutputList(input_maxes), - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Slice a `SparseTensor` based on the `start` and `size`. -// -// For example, if the input is -// -// input_tensor = shape = [2, 7] -// [ a d e ] -// [b c ] -// -// Graphically the output tensors are: -// -// sparse_slice([0, 0], [2, 4]) = shape = [2, 4] -// [ a ] -// [b c ] -// -// sparse_slice([0, 4], [2, 3]) = shape = [2, 3] -// [ d e ] -// [ ] -// -// Arguments: -// indices: 2-D tensor represents the indices of the sparse tensor. -// values: 1-D tensor represents the values of the sparse tensor. -// shape: 1-D. tensor represents the shape of the sparse tensor. -// start: 1-D. tensor represents the start of the slice. -// size: 1-D. tensor represents the size of the slice. -// output indices: A list of 1-D tensors represents the indices of the output -// sparse tensors. -// -// Returns A list of 1-D tensors represents the values of the output sparse -// tensors.A list of 1-D tensors represents the shape of the output sparse -// tensors. -func SparseSlice(scope *Scope, indices tf.Output, values tf.Output, shape tf.Output, start tf.Output, size tf.Output) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSlice", - Input: []tf.Input{ - indices, values, shape, start, size, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - // Returns the element-wise min of two SparseTensors. // // Assumes the two SparseTensors have the same shape, i.e., no broadcasting. @@ -17978,52 +17895,6 @@ func TakeManySparseFromTensorsMap(scope *Scope, sparse_handles tf.Output, dtype return op.Output(0), op.Output(1), op.Output(2) } -// MaxPoolAttr is an optional argument to MaxPool. -type MaxPoolAttr func(optionalAttr) - -// MaxPoolDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func MaxPoolDataFormat(value string) MaxPoolAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Performs max pooling on the input. -// -// Arguments: -// input: 4-D input to pool over. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. -// -// Returns The max pooled output tensor. -func MaxPool(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPool", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Assigns a new value to a variable. // // Any ReadVariableOp with a control dependency on this op is guaranteed to return @@ -18601,72 +18472,9 @@ func SdcaOptimizer(scope *Scope, sparse_example_indices []tf.Output, sparse_feat } if out_delta_dense_weights, idx, err = makeOutputList(op, idx, "out_delta_dense_weights"); err != nil { scope.UpdateErr("SdcaOptimizer", err) - return - } - return out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights -} - -// SparseMatMulAttr is an optional argument to SparseMatMul. -type SparseMatMulAttr func(optionalAttr) - -// SparseMatMulTransposeA sets the optional transpose_a attribute to value. -// If not specified, defaults to false -func SparseMatMulTransposeA(value bool) SparseMatMulAttr { - return func(m optionalAttr) { - m["transpose_a"] = value - } -} - -// SparseMatMulTransposeB sets the optional transpose_b attribute to value. -// If not specified, defaults to false -func SparseMatMulTransposeB(value bool) SparseMatMulAttr { - return func(m optionalAttr) { - m["transpose_b"] = value - } -} - -// SparseMatMulAIsSparse sets the optional a_is_sparse attribute to value. -// If not specified, defaults to false -func SparseMatMulAIsSparse(value bool) SparseMatMulAttr { - return func(m optionalAttr) { - m["a_is_sparse"] = value - } -} - -// SparseMatMulBIsSparse sets the optional b_is_sparse attribute to value. -// If not specified, defaults to false -func SparseMatMulBIsSparse(value bool) SparseMatMulAttr { - return func(m optionalAttr) { - m["b_is_sparse"] = value - } -} - -// Multiply matrix "a" by matrix "b". -// -// The inputs must be two-dimensional matrices and the inner dimension of "a" must -// match the outer dimension of "b". This op is optimized for the case where at -// least one of "a" or "b" is sparse. The breakeven for using this versus a dense -// matrix multiply on one platform was 30% zero values in the sparse matrix. -// -// The gradient computation of this operation will only take advantage of sparsity -// in the input gradient when that gradient comes from a Relu. -func SparseMatMul(scope *Scope, a tf.Output, b tf.Output, optional ...SparseMatMulAttr) (product tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SparseMatMul", - Input: []tf.Input{ - a, b, - }, - Attrs: attrs, + return } - op := scope.AddOperation(opspec) - return op.Output(0) + return out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights } // ShapeAttr is an optional argument to Shape. @@ -19514,6 +19322,228 @@ func OrderedMapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...Or return op.Output(0) } +// LRNGradAttr is an optional argument to LRNGrad. +type LRNGradAttr func(optionalAttr) + +// LRNGradDepthRadius sets the optional depth_radius attribute to value. +// +// value: A depth radius. +// If not specified, defaults to 5 +func LRNGradDepthRadius(value int64) LRNGradAttr { + return func(m optionalAttr) { + m["depth_radius"] = value + } +} + +// LRNGradBias sets the optional bias attribute to value. +// +// value: An offset (usually > 0 to avoid dividing by 0). +// If not specified, defaults to 1 +func LRNGradBias(value float32) LRNGradAttr { + return func(m optionalAttr) { + m["bias"] = value + } +} + +// LRNGradAlpha sets the optional alpha attribute to value. +// +// value: A scale factor, usually positive. +// If not specified, defaults to 1 +func LRNGradAlpha(value float32) LRNGradAttr { + return func(m optionalAttr) { + m["alpha"] = value + } +} + +// LRNGradBeta sets the optional beta attribute to value. +// +// value: An exponent. +// If not specified, defaults to 0.5 +func LRNGradBeta(value float32) LRNGradAttr { + return func(m optionalAttr) { + m["beta"] = value + } +} + +// Gradients for Local Response Normalization. +// +// Arguments: +// input_grads: 4-D with shape `[batch, height, width, channels]`. +// input_image: 4-D with shape `[batch, height, width, channels]`. +// output_image: 4-D with shape `[batch, height, width, channels]`. +// +// Returns The gradients for LRN. +func LRNGrad(scope *Scope, input_grads tf.Output, input_image tf.Output, output_image tf.Output, optional ...LRNGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LRNGrad", + Input: []tf.Input{ + input_grads, input_image, output_image, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// AnyAttr is an optional argument to Any. +type AnyAttr func(optionalAttr) + +// AnyKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func AnyKeepDims(value bool) AnyAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the "logical or" of elements across dimensions of a tensor. +// +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. +// +// Arguments: +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. +// +// Returns The reduced tensor. +func Any(scope *Scope, input tf.Output, axis tf.Output, optional ...AnyAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Any", + Input: []tf.Input{ + input, axis, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a sequence of numbers. +// +// This operation creates a sequence of numbers that begins at `start` and +// extends by increments of `delta` up to but not including `limit`. +// +// For example: +// +// ``` +// # 'start' is 3 +// # 'limit' is 18 +// # 'delta' is 3 +// tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15] +// ``` +// +// Arguments: +// start: 0-D (scalar). First entry in the sequence. +// limit: 0-D (scalar). Upper limit of sequence, exclusive. +// delta: 0-D (scalar). Optional. Default is 1. Number that increments `start`. +// +// Returns 1-D. +func Range(scope *Scope, start tf.Output, limit tf.Output, delta tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Range", + Input: []tf.Input{ + start, limit, delta, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DestroyResourceOpAttr is an optional argument to DestroyResourceOp. +type DestroyResourceOpAttr func(optionalAttr) + +// DestroyResourceOpIgnoreLookupError sets the optional ignore_lookup_error attribute to value. +// +// value: whether to ignore the error when the resource +// doesn't exist. +// If not specified, defaults to true +func DestroyResourceOpIgnoreLookupError(value bool) DestroyResourceOpAttr { + return func(m optionalAttr) { + m["ignore_lookup_error"] = value + } +} + +// Deletes the resource specified by the handle. +// +// All subsequent operations using the resource will result in a NotFound +// error status. +// +// Arguments: +// resource: handle to the resource to delete. +// +// Returns the created operation. +func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyResourceOpAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DestroyResourceOp", + Input: []tf.Input{ + resource, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Generates values in an interval. +// +// A sequence of `num` evenly-spaced values are generated beginning at `start`. +// If `num > 1`, the values in the sequence increase by `stop - start / num - 1`, +// so that the last one is exactly `stop`. +// +// For example: +// +// ``` +// tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0] +// ``` +// +// Arguments: +// start: First entry in the range. +// stop: Last entry in the range. +// num: Number of values to generate. +// +// Returns 1-D. The generated values. +func LinSpace(scope *Scope, start tf.Output, stop tf.Output, num tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LinSpace", + Input: []tf.Input{ + start, stop, num, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // ComplexAttr is an optional argument to Complex. type ComplexAttr func(optionalAttr) @@ -30680,33 +30710,3 @@ func InplaceSub(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Outpu op := scope.AddOperation(opspec) return op.Output(0) } - -// Converts a flat index or array of flat indices into a tuple of -// -// coordinate arrays. -// -// @compatibility(numpy) -// Equivalent to np.unravel_index -// @end_compatibility -// -// Arguments: -// indices: An 0-D or 1-D `int` Tensor whose elements are indices into the -// flattened version of an array of dimensions dims. -// dims: An 1-D `int` Tensor. The shape of the array to use for unraveling -// indices. -// -// Returns An 2-D (or 1-D if indices is 0-D) tensor where each row has the -// same shape as the indices array. -func UnravelIndex(scope *Scope, indices tf.Output, dims tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "UnravelIndex", - Input: []tf.Input{ - indices, dims, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go index 2d25c04dc9b1d0bc2ae831f98c0879e73a6bfafa..f3338f6595793df82380f4ce63058ba4285c91dd 100644 --- a/tensorflow/go/tensor.go +++ b/tensorflow/go/tensor.go @@ -131,13 +131,9 @@ func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error) } runtime.SetFinalizer(t, (*Tensor).finalize) raw := tensorData(t.c) - n, err := r.Read(raw) - if err != nil { + if _, err := io.ReadFull(r, raw); err != nil { return nil, err } - if uintptr(n) != nbytes { - return nil, fmt.Errorf("expected serialized tensor to be %v bytes, read %v", nbytes, n) - } return t, nil } diff --git a/tensorflow/go/tensor_test.go b/tensorflow/go/tensor_test.go index 793c36dd4db28fc5fdb713095c6d1d6713367a7a..dc533cd3e1c7198f902b2db850e8daff50f4cdeb 100644 --- a/tensorflow/go/tensor_test.go +++ b/tensorflow/go/tensor_test.go @@ -18,6 +18,7 @@ package tensorflow import ( "bytes" + "io" "reflect" "testing" ) @@ -226,6 +227,54 @@ func TestTensorSerializationErrors(t *testing.T) { } } +func TestReadTensorReadAll(t *testing.T) { + // Get the bytes of a tensor. + a := []float32{1.1, 1.2, 1.3} + ats, err := NewTensor(a) + if err != nil { + t.Fatal(err) + } + abuf := new(bytes.Buffer) + if _, err := ats.WriteContentsTo(abuf); err != nil { + t.Fatal(err) + } + + // Get the bytes of another tensor. + b := []float32{1.1, 1.2, 1.3} + bts, err := NewTensor(b) + if err != nil { + t.Fatal(err) + } + bbuf := new(bytes.Buffer) + if _, err := bts.WriteContentsTo(bbuf); err != nil { + t.Fatal(err) + } + + // Check that ReadTensor reads all bytes of both tensors, when the situation + // requires one than reads. + abbuf := io.MultiReader(abuf, bbuf) + abts, err := ReadTensor(Float, []int64{2, 3}, abbuf) + if err != nil { + t.Fatal(err) + } + abtsf32 := abts.Value().([][]float32) + expected := [][]float32{a, b} + + if len(abtsf32) != 2 { + t.Fatalf("first dimension %d is not 2", len(abtsf32)) + } + for i := 0; i < 2; i++ { + if len(abtsf32[i]) != 3 { + t.Fatalf("second dimension %d is not 3", len(abtsf32[i])) + } + for j := 0; j < 3; j++ { + if abtsf32[i][j] != expected[i][j] { + t.Errorf("value at %d %d not equal %f %f", i, j, abtsf32[i][j], expected[i][j]) + } + } + } +} + func benchmarkNewTensor(b *testing.B, v interface{}) { for i := 0; i < b.N; i++ { if t, err := NewTensor(v); err != nil || t == nil { diff --git a/tensorflow/java/README.md b/tensorflow/java/README.md index 2f1ce253b2facb6d86d5c44b60668823f660ae7e..c7382ff23138cd8121718d0b7552da0f0a2d78af 100644 --- a/tensorflow/java/README.md +++ b/tensorflow/java/README.md @@ -1,7 +1,7 @@ # TensorFlow for Java > *WARNING*: The TensorFlow Java API is not currently covered by the TensorFlow -> [API stability guarantees](https://www.tensorflow.org/programmers_guide/version_semantics). +> [API stability guarantees](https://www.tensorflow.org/guide/version_semantics). > > For using TensorFlow on Android refer instead to > [contrib/android](https://www.tensorflow.org/code/tensorflow/contrib/android), @@ -23,8 +23,7 @@ native libraries will need to be built from source. 2. Setup the environment to build TensorFlow from source code ([Linux](https://www.tensorflow.org/install/install_sources#PrepareLinux) - or [Mac OS - X](https://www.tensorflow.org/install/install_sources#PrepareMac)). + or [macOS](https://www.tensorflow.org/install/install_sources#PrepareMac)). If you'd like to skip reading those details and do not care about GPU support, try the following: diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml index 38e87b16399e7b9344f654b553fa1623b6b2d9cd..a7fa9ea5cc78f9d83cfb105f09837e958c60d5b4 100644 --- a/tensorflow/java/maven/libtensorflow/pom.xml +++ b/tensorflow/java/maven/libtensorflow/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.9.0-rc0 + 1.9.0-rc1 ../ libtensorflow diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml index 36c984e280199c97a6e07516cc84290fa91e6b27..83aae29f1ea0f893c40597a1be6f77668d8206e9 100644 --- a/tensorflow/java/maven/libtensorflow_jni/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.9.0-rc0 + 1.9.0-rc1 ../ libtensorflow_jni diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml index 4c846de05ad415bedf2c14b6a07ff8d5bc6f11b8..50bd8ee5f9e6d268976540ca8180380447bc8f18 100644 --- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.9.0-rc0 + 1.9.0-rc1 ../ libtensorflow_jni_gpu diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml index acab08f58c074144472f80eeaf894e25ad8389f8..b4746794ea9e417bb0bb9253ca356976a48eb1e8 100644 --- a/tensorflow/java/maven/pom.xml +++ b/tensorflow/java/maven/pom.xml @@ -6,7 +6,7 @@ 4.0.0 org.tensorflow parentpom - 1.9.0-rc0 + 1.9.0-rc1 pom https://www.tensorflow.org diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml index eb0a952c7d2f8960a387fa63e6c257e33b80bcbb..618a2a124c77240b0a2b65f33577a6330929ae83 100644 --- a/tensorflow/java/maven/proto/pom.xml +++ b/tensorflow/java/maven/proto/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.9.0-rc0 + 1.9.0-rc1 ../ proto diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml index 48668a47f2839e06eba774cae75b03663985ea28..157c4b8e82d6b8062ce8c9c98432cfe97a20d190 100644 --- a/tensorflow/java/maven/tensorflow/pom.xml +++ b/tensorflow/java/maven/tensorflow/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.9.0-rc0 + 1.9.0-rc1 ../ tensorflow diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index 2df69ee29996304569320c1dbbcaa46f214d4ea0..d5bd99bdd9d71f73288661380ec45e76c797fa75 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -36,20 +36,21 @@ namespace java { namespace { constexpr const char kLicense[] = - "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n" - "\n" - "Licensed under the Apache License, Version 2.0 (the \"License\");\n" - "you may not use this file except in compliance with the License.\n" - "You may obtain a copy of the License at\n" - "\n" - " http://www.apache.org/licenses/LICENSE-2.0\n" - "\n" - "Unless required by applicable law or agreed to in writing, software\n" - "distributed under the License is distributed on an \"AS IS\" BASIS,\n" - "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" - "See the License for the specific language governing permissions and\n" - "limitations under the License.\n" - "=======================================================================*/\n"; + "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n" + "\n" + "Licensed under the Apache License, Version 2.0 (the \"License\");\n" + "you may not use this file except in compliance with the License.\n" + "You may obtain a copy of the License at\n" + "\n" + " http://www.apache.org/licenses/LICENSE-2.0\n" + "\n" + "Unless required by applicable law or agreed to in writing, software\n" + "distributed under the License is distributed on an \"AS IS\" BASIS,\n" + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" + "See the License for the specific language governing permissions and\n" + "limitations under the License.\n" + "=======================================================================*/" + "\n"; // There is three different modes to render an op class, depending on the // number and type of outputs it has: diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc index f0e4bcca8253be60751066968938be8226b793e2..63e99fbb04fd6ba34f2bbd2bc3fe7644a31ddf7f 100644 --- a/tensorflow/java/src/gen/cc/op_specs.cc +++ b/tensorflow/java/src/gen/cc/op_specs.cc @@ -97,6 +97,7 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, *iterable_out = true; visited_attrs_.insert(std::make_pair(arg_def.number_attr(), Type::Int())); } + Type type = Type::Wildcard(); if (arg_def.type() != DataType::DT_INVALID) { // resolve type from DataType diff --git a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java index 3524160d876ac89306203891357f27946d9e368f..796d6a62dcf8551d8d68d9ff62077e7f09db4401 100644 --- a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java +++ b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java @@ -15,6 +15,18 @@ limitations under the License. package org.tensorflow.processor; +import com.google.common.base.CaseFormat; +import com.google.common.base.Strings; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.Multimap; +import com.squareup.javapoet.ClassName; +import com.squareup.javapoet.FieldSpec; +import com.squareup.javapoet.JavaFile; +import com.squareup.javapoet.MethodSpec; +import com.squareup.javapoet.ParameterSpec; +import com.squareup.javapoet.TypeName; +import com.squareup.javapoet.TypeSpec; +import com.squareup.javapoet.TypeVariableName; import java.io.IOException; import java.util.Collection; import java.util.Collections; @@ -23,7 +35,6 @@ import java.util.Map; import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; - import javax.annotation.processing.AbstractProcessor; import javax.annotation.processing.Filer; import javax.annotation.processing.Messager; @@ -44,19 +55,6 @@ import javax.lang.model.util.ElementFilter; import javax.lang.model.util.Elements; import javax.tools.Diagnostic.Kind; -import com.google.common.base.CaseFormat; -import com.google.common.base.Strings; -import com.google.common.collect.HashMultimap; -import com.google.common.collect.Multimap; -import com.squareup.javapoet.ClassName; -import com.squareup.javapoet.FieldSpec; -import com.squareup.javapoet.JavaFile; -import com.squareup.javapoet.MethodSpec; -import com.squareup.javapoet.ParameterSpec; -import com.squareup.javapoet.TypeName; -import com.squareup.javapoet.TypeSpec; -import com.squareup.javapoet.TypeVariableName; - /** * A compile-time Processor that aggregates classes annotated with {@link * org.tensorflow.op.annotation.Operator} and generates the {@code Ops} convenience API. Please @@ -115,10 +113,12 @@ public final class OperatorProcessor extends AbstractProcessor { // generated our code, flag the location of each such class. if (hasRun) { for (Element e : annotated) { - error(e, "The Operator processor has already processed @Operator annotated sources\n" + - "and written out an Ops API. It cannot process additional @Operator sources.\n" + - "One reason this can happen is if other annotation processors generate\n" + - "new @Operator source files."); + error( + e, + "The Operator processor has already processed @Operator annotated sources\n" + + "and written out an Ops API. It cannot process additional @Operator sources.\n" + + "One reason this can happen is if other annotation processors generate\n" + + "new @Operator source files."); } return true; } @@ -146,9 +146,11 @@ public final class OperatorProcessor extends AbstractProcessor { return Collections.singleton("org.tensorflow.op.annotation.Operator"); } - private static final Pattern JAVADOC_TAG_PATTERN = Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*"); + private static final Pattern JAVADOC_TAG_PATTERN = + Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*"); private static final TypeName T_OPS = ClassName.get("org.tensorflow.op", "Ops"); - private static final TypeName T_OPERATOR = ClassName.get("org.tensorflow.op.annotation", "Operator"); + private static final TypeName T_OPERATOR = + ClassName.get("org.tensorflow.op.annotation", "Operator"); private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope"); private static final TypeName T_GRAPH = ClassName.get("org.tensorflow", "Graph"); private static final TypeName T_STRING = ClassName.get(String.class); @@ -167,20 +169,17 @@ public final class OperatorProcessor extends AbstractProcessor { private void write(TypeSpec spec) { try { - JavaFile.builder("org.tensorflow.op", spec) - .skipJavaLangImports(true) - .build() - .writeTo(filer); + JavaFile.builder("org.tensorflow.op", spec).skipJavaLangImports(true).build().writeTo(filer); } catch (IOException e) { throw new AssertionError(e); } } private void writeApi(Multimap groupedMethods) { - Map groups = new HashMap(); - + Map groups = new HashMap<>(); + // Generate a API class for each group collected other than the default one (= empty string) - for (Map.Entry> entry: groupedMethods.asMap().entrySet()) { + for (Map.Entry> entry : groupedMethods.asMap().entrySet()) { if (!entry.getKey().isEmpty()) { TypeSpec groupClass = buildGroupClass(entry.getKey(), entry.getValue()); write(groupClass); @@ -193,12 +192,17 @@ public final class OperatorProcessor extends AbstractProcessor { } private boolean collectOpsMethods( - RoundEnvironment roundEnv, Multimap groupedMethods, TypeElement annotation) { + RoundEnvironment roundEnv, + Multimap groupedMethods, + TypeElement annotation) { boolean result = true; for (Element e : roundEnv.getElementsAnnotatedWith(annotation)) { // @Operator can only apply to types, so e must be a TypeElement. if (!(e instanceof TypeElement)) { - error(e, "@Operator can only be applied to classes, but this is a %s", e.getKind().toString()); + error( + e, + "@Operator can only be applied to classes, but this is a %s", + e.getKind().toString()); result = false; continue; } @@ -210,38 +214,42 @@ public final class OperatorProcessor extends AbstractProcessor { } return result; } - - private void collectOpMethods(Multimap groupedMethods, TypeElement opClass, TypeElement annotation) { + + private void collectOpMethods( + Multimap groupedMethods, TypeElement opClass, TypeElement annotation) { AnnotationMirror am = getAnnotationMirror(opClass, annotation); String groupName = getAnnotationElementValueAsString("group", am); String methodName = getAnnotationElementValueAsString("name", am); ClassName opClassName = ClassName.get(opClass); if (Strings.isNullOrEmpty(methodName)) { - methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName()); + methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName()); } - // Build a method for each @Operator found in the class path. There should be one method per operation factory called + // Build a method for each @Operator found in the class path. There should be one method per + // operation factory called // "create", which takes in parameter a scope and, optionally, a list of arguments for (ExecutableElement opMethod : ElementFilter.methodsIn(opClass.getEnclosedElements())) { - if (opMethod.getModifiers().contains(Modifier.STATIC) && opMethod.getSimpleName().contentEquals("create")) { + if (opMethod.getModifiers().contains(Modifier.STATIC) + && opMethod.getSimpleName().contentEquals("create")) { MethodSpec method = buildOpMethod(methodName, opClassName, opMethod); groupedMethods.put(groupName, method); } } } - private MethodSpec buildOpMethod(String methodName, ClassName opClassName, ExecutableElement factoryMethod) { + private MethodSpec buildOpMethod( + String methodName, ClassName opClassName, ExecutableElement factoryMethod) { MethodSpec.Builder builder = MethodSpec.methodBuilder(methodName) - .addModifiers(Modifier.PUBLIC) - .returns(TypeName.get(factoryMethod.getReturnType())) - .varargs(factoryMethod.isVarArgs()) - .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod)); + .addModifiers(Modifier.PUBLIC) + .returns(TypeName.get(factoryMethod.getReturnType())) + .varargs(factoryMethod.isVarArgs()) + .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod)); - for (TypeParameterElement tp: factoryMethod.getTypeParameters()) { + for (TypeParameterElement tp : factoryMethod.getTypeParameters()) { TypeVariableName tvn = TypeVariableName.get((TypeVariable) tp.asType()); builder.addTypeVariable(tvn); } - for (TypeMirror thrownType: factoryMethod.getThrownTypes()) { + for (TypeMirror thrownType : factoryMethod.getThrownTypes()) { builder.addException(TypeName.get(thrownType)); } StringBuilder call = new StringBuilder("return $T.create(scope"); @@ -259,13 +267,17 @@ public final class OperatorProcessor extends AbstractProcessor { call.append(")"); builder.addStatement(call.toString(), opClassName); return builder.build(); - } - + } + private String buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod) { StringBuilder javadoc = new StringBuilder(); - javadoc.append("Adds an {@link ").append(opClassName.simpleName()).append("} operation to the graph\n\n"); + javadoc + .append("Adds an {@link ") + .append(opClassName.simpleName()) + .append("} operation to the graph\n\n"); - // Add all javadoc tags found in the operator factory method but the first one, which should be in all cases the + // Add all javadoc tags found in the operator factory method but the first one, which should be + // in all cases the // 'scope' parameter that is implicitly passed by this API Matcher tagMatcher = JAVADOC_TAG_PATTERN.matcher(elements.getDocComment(factoryMethod)); boolean firstParam = true; @@ -277,136 +289,144 @@ public final class OperatorProcessor extends AbstractProcessor { } else { javadoc.append(tag).append('\n'); } - } + } javadoc.append("@see {@link ").append(opClassName).append("}\n"); return javadoc.toString(); } - + private static TypeSpec buildGroupClass(String group, Collection methods) { MethodSpec.Builder ctorBuilder = MethodSpec.constructorBuilder() - .addParameter(T_SCOPE, "scope") - .addStatement("this.scope = scope"); - + .addParameter(T_SCOPE, "scope") + .addStatement("this.scope = scope"); + TypeSpec.Builder builder = TypeSpec.classBuilder(CaseFormat.LOWER_CAMEL.to(CaseFormat.UPPER_CAMEL, group) + "Ops") - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .addJavadoc("An API for adding {@code $L} operations to a {@link $T Graph}\n\n" + - "@see {@link $T}\n", group, T_GRAPH, T_OPS) - .addMethods(methods) - .addMethod(ctorBuilder.build()); + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .addJavadoc( + "An API for adding {@code $L} operations to a {@link $T Graph}\n\n" + + "@see {@link $T}\n", + group, + T_GRAPH, + T_OPS) + .addMethods(methods) + .addMethod(ctorBuilder.build()); builder.addField( - FieldSpec.builder(T_SCOPE, "scope") - .addModifiers(Modifier.PRIVATE, Modifier.FINAL) - .build()); + FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); return builder.build(); } - private static TypeSpec buildTopClass(Map groupToClass, Collection methods) { + private static TypeSpec buildTopClass( + Map groupToClass, Collection methods) { MethodSpec.Builder ctorBuilder = MethodSpec.constructorBuilder() - .addModifiers(Modifier.PRIVATE) - .addParameter(T_SCOPE, "scope") - .addStatement("this.scope = scope", T_SCOPE); + .addModifiers(Modifier.PRIVATE) + .addParameter(T_SCOPE, "scope") + .addStatement("this.scope = scope", T_SCOPE); - for (Map.Entry entry: groupToClass.entrySet()) { + for (Map.Entry entry : groupToClass.entrySet()) { ctorBuilder.addStatement("$L = new $T(scope)", entry.getKey(), entry.getValue()); } TypeSpec.Builder opsBuilder = TypeSpec.classBuilder("Ops") - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .addJavadoc("An API for building a {@link $T} with operation wrappers\n

\n" + - "Any operation wrapper found in the classpath properly annotated as an {@link $T @Operator} is exposed\n" + - "by this API or one of its subgroup.\n

Example usage:\n

{@code\n" +
-            "try (Graph g = new Graph()) {\n" +
-            "  Ops ops = new Ops(g);\n" +
-            "  // Operations are typed classes with convenience\n" +
-            "  // builders in Ops.\n" +
-            "  Constant three = ops.constant(3);\n" +
-            "  // Single-result operations implement the Operand\n" +
-            "  // interface, so this works too.\n" +
-            "  Operand four = ops.constant(4);\n" +
-            "  // Most builders are found within a group, and accept\n" +
-            "  // Operand types as operands\n" +
-            "  Operand nine = ops.math().add(four, ops.constant(5));\n" +
-            "  // Multi-result operations however offer methods to\n" +
-            "  // select a particular result for use.\n" +
-            "  Operand result = \n" +
-            "      ops.math().add(ops.array().unique(s, a).y(), b);\n" +
-            "  // Optional attributes\n" +
-            "  ops.math().matMul(a, b, MatMul.transposeA(true));\n" +
-            "  // Naming operators\n" +
-            "  ops.withName(ā€œfooā€).constant(5); // name ā€œfooā€\n" +
-            "  // Names can exist in a hierarchy\n" +
-            "  Ops sub = ops.withSubScope(ā€œsubā€);\n" +
-            "  sub.withName(ā€œbarā€).constant(4); // ā€œsub/barā€\n" +
-            "}\n" +
-            "}
\n", T_GRAPH, T_OPERATOR) - .addMethods(methods) - .addMethod(ctorBuilder.build()); + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .addJavadoc( + "An API for building a {@link $T} with operation wrappers\n

\n" + + "Any operation wrapper found in the classpath properly annotated as an" + + "{@link $T @Operator} is exposed\n" + + "by this API or one of its subgroup.\n

Example usage:\n

{@code\n"
+                    + "try (Graph g = new Graph()) {\n"
+                    + "  Ops ops = new Ops(g);\n"
+                    + "  // Operations are typed classes with convenience\n"
+                    + "  // builders in Ops.\n"
+                    + "  Constant three = ops.constant(3);\n"
+                    + "  // Single-result operations implement the Operand\n"
+                    + "  // interface, so this works too.\n"
+                    + "  Operand four = ops.constant(4);\n"
+                    + "  // Most builders are found within a group, and accept\n"
+                    + "  // Operand types as operands\n"
+                    + "  Operand nine = ops.math().add(four, ops.constant(5));\n"
+                    + "  // Multi-result operations however offer methods to\n"
+                    + "  // select a particular result for use.\n"
+                    + "  Operand result = \n"
+                    + "      ops.math().add(ops.array().unique(s, a).y(), b);\n"
+                    + "  // Optional attributes\n"
+                    + "  ops.math().matMul(a, b, MatMul.transposeA(true));\n"
+                    + "  // Naming operators\n"
+                    + "  ops.withName(ā€œfooā€).constant(5); // name ā€œfooā€\n"
+                    + "  // Names can exist in a hierarchy\n"
+                    + "  Ops sub = ops.withSubScope(ā€œsubā€);\n"
+                    + "  sub.withName(ā€œbarā€).constant(4); // ā€œsub/barā€\n"
+                    + "}\n"
+                    + "}
\n", + T_GRAPH, + T_OPERATOR) + .addMethods(methods) + .addMethod(ctorBuilder.build()); opsBuilder.addMethod( MethodSpec.methodBuilder("withSubScope") - .addModifiers(Modifier.PUBLIC) - .addParameter(T_STRING, "childScopeName") - .returns(T_OPS) - .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS) - .addJavadoc( - "Returns an API that adds operations to the graph with the provided name prefix.\n\n" + - "@see {@link $T#withSubScope(String)}\n", T_SCOPE) - .build()); + .addModifiers(Modifier.PUBLIC) + .addParameter(T_STRING, "childScopeName") + .returns(T_OPS) + .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS) + .addJavadoc( + "Returns an API that adds operations to the graph with the provided name prefix.\n" + + "\n@see {@link $T#withSubScope(String)}\n", + T_SCOPE) + .build()); opsBuilder.addMethod( MethodSpec.methodBuilder("withName") - .addModifiers(Modifier.PUBLIC) - .addParameter(T_STRING, "opName") - .returns(T_OPS) - .addStatement("return new Ops(scope.withName(opName))") - .addJavadoc( - "Returns an API that uses the provided name for an op.\n\n" + - "@see {@link $T#withName(String)}\n", T_SCOPE) - .build()); + .addModifiers(Modifier.PUBLIC) + .addParameter(T_STRING, "opName") + .returns(T_OPS) + .addStatement("return new Ops(scope.withName(opName))") + .addJavadoc( + "Returns an API that uses the provided name for an op.\n\n" + + "@see {@link $T#withName(String)}\n", + T_SCOPE) + .build()); opsBuilder.addField( - FieldSpec.builder(T_SCOPE, "scope") - .addModifiers(Modifier.PRIVATE, Modifier.FINAL) - .build()); + FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); opsBuilder.addMethod( MethodSpec.methodBuilder("scope") - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .returns(T_SCOPE) - .addStatement("return scope") - .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE) - .build()); + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .returns(T_SCOPE) + .addStatement("return scope") + .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE) + .build()); - for (Map.Entry entry: groupToClass.entrySet()) { + for (Map.Entry entry : groupToClass.entrySet()) { opsBuilder.addField( FieldSpec.builder(entry.getValue(), entry.getKey()) - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .build()); - + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .build()); + opsBuilder.addMethod( MethodSpec.methodBuilder(entry.getKey()) - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .returns(entry.getValue()) - .addStatement("return $L", entry.getKey()) - .addJavadoc("Returns an API for adding {@code $L} operations to the graph\n", entry.getKey()) - .build()); + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .returns(entry.getValue()) + .addStatement("return $L", entry.getKey()) + .addJavadoc( + "Returns an API for adding {@code $L} operations to the graph\n", entry.getKey()) + .build()); } opsBuilder.addMethod( MethodSpec.methodBuilder("create") - .addModifiers(Modifier.PUBLIC, Modifier.STATIC) - .addParameter(T_GRAPH, "graph") - .returns(T_OPS) - .addStatement("return new Ops(new $T(graph))", T_SCOPE) - .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n") - .build()); + .addModifiers(Modifier.PUBLIC, Modifier.STATIC) + .addParameter(T_GRAPH, "graph") + .returns(T_OPS) + .addStatement("return new Ops(new $T(graph))", T_SCOPE) + .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n") + .build()); return opsBuilder.build(); } @@ -417,12 +437,16 @@ public final class OperatorProcessor extends AbstractProcessor { return am; } } - throw new IllegalArgumentException("Annotation " + annotation.getSimpleName() + " not present on element " - + element.getSimpleName()); + throw new IllegalArgumentException( + "Annotation " + + annotation.getSimpleName() + + " not present on element " + + element.getSimpleName()); } - + private static String getAnnotationElementValueAsString(String elementName, AnnotationMirror am) { - for (Map.Entry entry : am.getElementValues().entrySet()) { + for (Map.Entry entry : + am.getElementValues().entrySet()) { if (entry.getKey().getSimpleName().contentEquals(elementName)) { return entry.getValue().getValue().toString(); } diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java index d4fd3db5f7325ae891832ff7b658f5d3ea0789a6..7d19696749bbbb944e591daf596562f13f6dc103 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java @@ -143,6 +143,82 @@ public final class Graph implements AutoCloseable { } } + /** + * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, + * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...} + *

+ * {@code dx} are used as initial gradients (which represent the symbolic partial derivatives of some loss function + * {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of {@code y}. + *

+ * If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all + * shapes in {@code y}. + * + * @param y output of the function to derive + * @param x inputs of the function for which partial derivatives are computed + * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y} + * @return the partial derivatives {@code dy} with the size of {@code x} + */ + public Output[] addGradients(Output[] y, Output[] x, Output[] dx) { + Output[] dy = new Output[x.length]; + final long[] yHandles = new long[y.length]; + final int[] yIndices = new int[y.length]; + final long[] xHandles = new long[x.length]; + final int[] xIndices = new int[x.length]; + long[] dxHandles = null; + int[] dxIndices = null; + + try (Reference ref = ref()) { + for (int i = 0; i < y.length; ++i) { + yHandles[i] = y[i].op().getUnsafeNativeHandle(); + yIndices[i] = y[i].index(); + } + for (int i = 0; i < x.length; ++i) { + xHandles[i] = x[i].op().getUnsafeNativeHandle(); + xIndices[i] = x[i].index(); + } + if (dx != null && dx.length > 0) { + dxHandles = new long[dx.length]; + dxIndices = new int[dx.length]; + + for (int i = 0; i < dx.length; ++i) { + dxHandles[i] = dx[i].op().getUnsafeNativeHandle(); + dxIndices[i] = dx[i].index(); + } + } + // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles + // of the gradient operations while the second holds the index of their output + // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain + // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...] + long[] dyHandlesAndIndices = + addGradients(ref.nativeHandle(), yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices); + int ndy = dyHandlesAndIndices.length >> 1; + if (ndy != dy.length) { + throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length + + " were expected"); + } + for (int i = 0, j = ndy; i < ndy; ++i, ++j) { + Operation op = new Operation(this, dyHandlesAndIndices[i]); + dy[i] = new Output<>(op, (int) dyHandlesAndIndices[j]); + } + } + return dy; + } + + /** + * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, + * i.e., {@code dy/dx_1, dy/dx_2...} + *

+ * This is a simplified version of {@link #addGradients(Output[], Output[], Output[]) where {@code y} is + * a single output and {@code dx} is null. + * + * @param y output of the function to derive + * @param x inputs of the function for which partial derivatives are computed + * @return the partial derivatives {@code dy} with the size of {@code x} + */ + public Output[] addGradients(Output y, Output[] x) { + return addGradients(new Output[]{y}, x, null); + } + private final Object nativeHandleLock = new Object(); private long nativeHandle; private int refcount = 0; @@ -254,6 +330,9 @@ public final class Graph implements AutoCloseable { private static native byte[] toGraphDef(long handle); + private static native long[] addGradients(long handle, long[] inputHandles, int[] inputIndices, + long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices); + static { TensorFlow.init(); } diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java new file mode 100644 index 0000000000000000000000000000000000000000..f4671c8af941dd732859080238fa48e0a22672b6 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java @@ -0,0 +1,153 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.op.Op; +import org.tensorflow.op.Operands; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Operator; + +/** + * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, + * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...} + *

+ * If {@code Options.dx()} values are set, they are as the initial symbolic partial derivatives of some loss + * function {@code L} w.r.t. {@code y}. {@code Options.dx()} must have the size of {@code y}. + *

+ * If {@code Options.dx()} is not set, the implementation will use dx of {@code OnesLike} for all + * shapes in {@code y}. + *

+ * The partial derivatives are returned in output {@code dy}, with the size of {@code x}. + *

+ * Example of usage: + *

{@code
+ * Gradients gradients = Gradients.create(scope, Arrays.asList(loss), Arrays.asList(w, b));
+ * 
+ * Constant alpha = ops.constant(1.0f, Float.class);
+ * ApplyGradientDescent.create(scope, w, alpha, gradients.dy(0));
+ * ApplyGradientDescent.create(scope, b, alpha, gradients.dy(1));
+ * }
+ */ +@Operator +public class Gradients implements Op, Iterable> { + + /** + * Optional attributes for {@link Gradients} + */ + public static class Options { + + /** + * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y} + * @return this option builder + */ + public Options dx(Iterable> dx) { + this.dx = dx; + return this; + } + + private Iterable> dx; + + private Options() { + } + } + + /** + * Adds gradients computation ops to the graph according to scope. + * + * @param scope current graph scope + * @param y outputs of the function to derive + * @param x inputs of the function for which partial derivatives are computed + * @param options carries optional attributes values + * @return a new instance of {@code Gradients} + */ + public static Gradients create(Scope scope, Iterable> y, Iterable> x, Options... options) { + Output[] dx = null; + if (options != null) { + for (Options opts : options) { + if (opts.dx != null) { + dx = Operands.asOutputs(opts.dx); + } + } + } + Output[] gradOutputs = scope.graph().addGradients(Operands.asOutputs(y), Operands.asOutputs(x), dx); + return new Gradients(Arrays.asList(gradOutputs)); + } + + /** + * Adds gradients computation ops to the graph according to scope. + * + * This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where {@code y} is + * a single output. + * + * @param scope current graph scope + * @param y output of the function to derive + * @param x inputs of the function for which partial derivatives are computed + * @param options carries optional attributes values + * @return a new instance of {@code Gradients} + */ + @SuppressWarnings({"unchecked", "rawtypes"}) + public static Gradients create(Scope scope, Operand y, Iterable> x, Options... options) { + return create(scope, (Iterable) Arrays.asList(y), x, options); + } + + /** + * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y} + * @return builder to add more options to this operation + */ + public Options dx(Iterable> dx) { + return new Options().dx(dx); + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator> iterator() { + return (Iterator) dy.iterator(); + } + + /** + * Partial derivatives of {@code y}s w.r.t. {@code x}s, with the size of {@code x} + */ + public List> dy() { + return dy; + } + + /** + * Returns a symbolic handle to one of the gradient operation output + *

+ * Warning: Does not check that the type of the tensor matches T. It is recommended to call + * this method with an explicit type parameter rather than letting it be inferred, e.g. {@code + * gradients.dy(0)} + * + * @param The expected element type of the tensors produced by this output. + * @param index The index of the output among the gradients added by this operation + */ + @SuppressWarnings("unchecked") + public Output dy(int index) { + return (Output) dy.get(index); + } + + private List> dy; + + private Gradients(List> dy) { + this.dy = dy; + } +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/package-info.java b/tensorflow/java/src/main/java/org/tensorflow/package-info.java index 521c5c610c1f775cf9174664f5b786786ce1181d..f353ee31459806eb2db98d23ac030c15258a77fb 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/package-info.java +++ b/tensorflow/java/src/main/java/org/tensorflow/package-info.java @@ -17,7 +17,7 @@ limitations under the License. * Defines classes to build, save, load and execute TensorFlow models. * *

WARNING: The API is currently experimental and is not covered by TensorFlow API stability + * href="https://www.tensorflow.org/guide/version_semantics">API stability * guarantees. See README.md for installation * instructions. diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc index 0fef15527586555e7d3fc2c76403c6e5888fb236..dac6a345e917b618f7f1234c27959069650b51b7 100644 --- a/tensorflow/java/src/main/native/graph_jni.cc +++ b/tensorflow/java/src/main/native/graph_jni.cc @@ -16,7 +16,9 @@ limitations under the License. #include "tensorflow/java/src/main/native/graph_jni.h" #include +#include #include "tensorflow/c/c_api.h" +#include "tensorflow/java/src/main/native/utils_jni.h" #include "tensorflow/java/src/main/native/exception_jni.h" namespace { @@ -130,3 +132,55 @@ Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) { TF_DeleteBuffer(buf); return ret; } + +JNIEXPORT jlongArray JNICALL +Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle, + jlongArray y_handles, jintArray y_indices, + jlongArray x_handles, jintArray x_indices, + jlongArray dx_handles, jintArray dx_indices) { + + TF_Graph* g = requireHandle(env, handle); + if (g == nullptr) return nullptr; + + const jint ny = env->GetArrayLength(y_handles); + const jint nx = env->GetArrayLength(x_handles); + + std::unique_ptr y(new TF_Output[ny]); + std::unique_ptr x(new TF_Output[nx]); + std::unique_ptr dx(nullptr); + std::unique_ptr dy(new TF_Output[nx]); + + resolveOutputs(env, "y", y_handles, y_indices, y.get(), ny); + resolveOutputs(env, "x", x_handles, x_indices, x.get(), nx); + if (dx_handles != nullptr) { + if (env->GetArrayLength(dx_handles) != ny) { + throwException(env, kIllegalArgumentException, + "expected %d, got %d dx handles", ny, + env->GetArrayLength(dx_handles)); + } + dx.reset(new TF_Output[ny]); + resolveOutputs(env, "dx", dx_handles, dx_indices, dx.get(), ny); + } + if (env->ExceptionCheck()) return nullptr; + + TF_Status* status = TF_NewStatus(); + TF_AddGradients(g, y.get(), ny, x.get(), nx, dx.get(), status, dy.get()); + + if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteStatus(status); + return nullptr; + } + TF_DeleteStatus(status); + + // returned array contains both op handles and output indices, in pair + jlongArray dy_handles_and_indices = env->NewLongArray(nx << 1); + jlong* dy_elems = env->GetLongArrayElements(dy_handles_and_indices, nullptr); + for (int i = 0, j = nx; i < nx; ++i, ++j) { + TF_Output dy_output = dy.get()[i]; + dy_elems[i] = reinterpret_cast(dy_output.oper); + dy_elems[j] = static_cast(dy_output.index); + } + env->ReleaseLongArrayElements(dy_handles_and_indices, dy_elems, 0); + + return dy_handles_and_indices; +} diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h index dd2e038332f7d39e6460d6cfef40a9df7e348758..4f87e8d5a79d3ac46f7813ba4344bbfda069b557 100644 --- a/tensorflow/java/src/main/native/graph_jni.h +++ b/tensorflow/java/src/main/native/graph_jni.h @@ -73,6 +73,15 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Graph_toGraphDef(JNIEnv *, jclass, jlong); +/* + * Class: org_tensorflow_Graph + * Method: name + * Signature: (J[J[I[J[I[J[I)[J + */ +JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(JNIEnv *, + jclass, jlong, jlongArray, jintArray, jlongArray, jintArray, jlongArray, + jintArray); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/java/src/main/native/session_jni.cc b/tensorflow/java/src/main/native/session_jni.cc index 2cd542d3c9be536a42037e9ef533ed629dd3ac9f..cb54daf13795c24e11566845892da6b5c4896cf5 100644 --- a/tensorflow/java/src/main/native/session_jni.cc +++ b/tensorflow/java/src/main/native/session_jni.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "tensorflow/c/c_api.h" +#include "tensorflow/java/src/main/native/utils_jni.h" #include "tensorflow/java/src/main/native/exception_jni.h" #include "tensorflow/java/src/main/native/session_jni.h" @@ -55,37 +56,6 @@ void resolveHandles(JNIEnv* env, const char* type, jlongArray src_array, env->ReleaseLongArrayElements(src_array, src_start, JNI_ABORT); } -void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op, - jintArray src_index, TF_Output* dst, jint n) { - if (env->ExceptionCheck()) return; - jint len = env->GetArrayLength(src_op); - if (len != n) { - throwException(env, kIllegalArgumentException, - "expected %d, got %d %s Operations", n, len, type); - return; - } - len = env->GetArrayLength(src_index); - if (len != n) { - throwException(env, kIllegalArgumentException, - "expected %d, got %d %s Operation output indices", n, len, - type); - return; - } - jlong* op_handles = env->GetLongArrayElements(src_op, nullptr); - jint* indices = env->GetIntArrayElements(src_index, nullptr); - for (int i = 0; i < n; ++i) { - if (op_handles[i] == 0) { - throwException(env, kNullPointerException, "invalid %s (#%d of %d)", type, - i, n); - break; - } - dst[i] = TF_Output{reinterpret_cast(op_handles[i]), - static_cast(indices[i])}; - } - env->ReleaseIntArrayElements(src_index, indices, JNI_ABORT); - env->ReleaseLongArrayElements(src_op, op_handles, JNI_ABORT); -} - void TF_MaybeDeleteBuffer(TF_Buffer* buf) { if (buf == nullptr) return; TF_DeleteBuffer(buf); diff --git a/tensorflow/java/src/main/native/utils_jni.cc b/tensorflow/java/src/main/native/utils_jni.cc new file mode 100644 index 0000000000000000000000000000000000000000..069ac05a1c39408dc02f5bbf9a7fc50fd095cc96 --- /dev/null +++ b/tensorflow/java/src/main/native/utils_jni.cc @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/java/src/main/native/utils_jni.h" + +#include "tensorflow/java/src/main/native/exception_jni.h" + +void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op, + jintArray src_index, TF_Output* dst, jint n) { + if (env->ExceptionCheck()) return; + jint len = env->GetArrayLength(src_op); + if (len != n) { + throwException(env, kIllegalArgumentException, + "expected %d, got %d %s Operations", n, len, type); + return; + } + len = env->GetArrayLength(src_index); + if (len != n) { + throwException(env, kIllegalArgumentException, + "expected %d, got %d %s Operation output indices", n, len, + type); + return; + } + jlong* op_handles = env->GetLongArrayElements(src_op, nullptr); + jint* indices = env->GetIntArrayElements(src_index, nullptr); + for (int i = 0; i < n; ++i) { + if (op_handles[i] == 0) { + throwException(env, kNullPointerException, "invalid %s (#%d of %d)", type, + i, n); + break; + } + dst[i] = TF_Output{reinterpret_cast(op_handles[i]), + static_cast(indices[i])}; + } + env->ReleaseIntArrayElements(src_index, indices, JNI_ABORT); + env->ReleaseLongArrayElements(src_op, op_handles, JNI_ABORT); +} + + + + diff --git a/tensorflow/java/src/main/native/utils_jni.h b/tensorflow/java/src/main/native/utils_jni.h new file mode 100644 index 0000000000000000000000000000000000000000..352298e7de1d07cebc1a287774c9bef85c9a6ae4 --- /dev/null +++ b/tensorflow/java/src/main/native/utils_jni.h @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_JAVA_UTILS_JNI_H_ +#define TENSORFLOW_JAVA_UTILS_JNI_H_ + +#include + +#include "tensorflow/c/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op, + jintArray src_index, TF_Output* dst, jint n); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif /* TENSORFLOW_JAVA_UTILS_JNI_H_ */ diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java index c540299bdcfcd7bc5969caf82b29144bad24201f..c2e52c22c6dc58a3002b536e64c4607b675804f7 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertTrue; import java.util.HashSet; import java.util.Iterator; + import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -129,4 +130,106 @@ public class GraphTest { // expected exception. } } + + @Test + public void addGradientsToGraph() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + + Output x1 = TestUtil.placeholder(g, "x1", Float.class); + Output x2 = TestUtil.placeholder(g, "x2", Float.class); + Output y0 = TestUtil.square(g, "y0", x1); + Output y1 = TestUtil.square(g, "y1", y0); + Output y2 = TestUtil.addN(g, y0, x2); + + Output[] grads0 = g.addGradients(y1, toArray(x1)); + assertNotNull(grads0); + assertEquals(1, grads0.length); + assertEquals(DataType.FLOAT, grads0[0].dataType()); + + Output[] grads1 = g.addGradients(y2, toArray(x1, x2)); + assertNotNull(grads1); + assertEquals(2, grads1.length); + assertEquals(DataType.FLOAT, grads1[0].dataType()); + assertEquals(DataType.FLOAT, grads1[1].dataType()); + + try (Tensor c1 = Tensors.create(3.0f); + Tensor c2 = Tensors.create(2.0f); + TestUtil.AutoCloseableList> outputs = new TestUtil.AutoCloseableList<>( + s.runner() + .feed(x1, c1) + .feed(x2, c2) + .fetch(grads0[0]) + .fetch(grads1[0]) + .fetch(grads1[1]) + .run())) { + + assertEquals(3, outputs.size()); + assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f); + assertEquals(6.0f, outputs.get(1).floatValue(), 0.0f); + assertEquals(1.0f, outputs.get(2).floatValue(), 0.0f); + } + } + } + + @Test + public void addGradientSumsToGraph() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + + Output x = TestUtil.placeholder(g, "x", Float.class); + Output y0 = TestUtil.square(g, "y0", x); + Output y1 = TestUtil.square(g, "y1", y0); + + Output[] grad = g.addGradients(toArray(y0, y1), toArray(x), null); + assertNotNull(grad); + assertEquals(1, grad.length); + assertEquals(DataType.FLOAT, grad[0].dataType()); + + try (Tensor c = Tensors.create(3.0f); + Tensor output = s.runner() + .feed(x, c) + .fetch(grad[0]) + .run() + .get(0)) { + + assertEquals(114.0f, output.floatValue(), 0.0f); + } + } + } + + @Test + public void addGradientsWithInitialValuesToGraph() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + + Output x = TestUtil.placeholder(g, "x", Float.class); + Output y0 = TestUtil.square(g, "y0", x); + Output y1 = TestUtil.square(g, "y1", y0); + + Output[] grad0 = g.addGradients(y1, toArray(y0)); + assertNotNull(grad0); + assertEquals(1, grad0.length); + assertEquals(DataType.FLOAT, grad0[0].dataType()); + + Output[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0])); + assertNotNull(grad1); + assertEquals(1, grad1.length); + assertEquals(DataType.FLOAT, grad1[0].dataType()); + + try (Tensor c = Tensors.create(3.0f); + Tensor output = s.runner() + .feed(x, c) + .fetch(grad1[0]) + .run() + .get(0)) { + + assertEquals(108.0f, output.floatValue(), 0.0f); + } + } + } + + private static Output[] toArray(Output... outputs) { + return outputs; + } } diff --git a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java index e8cc76c2a6458193161a98e17483fe73de107b77..7d5980bcdedebedcd2fa4722e85abc1d598fb4fd 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java @@ -20,8 +20,6 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import java.util.ArrayList; -import java.util.Collection; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -36,8 +34,8 @@ public class SessionTest { Session s = new Session(g)) { TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}}); try (Tensor x = Tensors.create(new int[][] {{5}, {7}}); - AutoCloseableList> outputs = - new AutoCloseableList>(s.runner().feed("X", x).fetch("Y").run())) { + TestUtil.AutoCloseableList> outputs = + new TestUtil.AutoCloseableList>(s.runner().feed("X", x).fetch("Y").run())) { assertEquals(1, outputs.size()); final int[][] expected = {{31}}; assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); @@ -53,8 +51,8 @@ public class SessionTest { Output feed = g.operation("X").output(0); Output fetch = g.operation("Y").output(0); try (Tensor x = Tensors.create(new int[][] {{5}, {7}}); - AutoCloseableList> outputs = - new AutoCloseableList>(s.runner().feed(feed, x).fetch(fetch).run())) { + TestUtil.AutoCloseableList> outputs = + new TestUtil.AutoCloseableList>(s.runner().feed(feed, x).fetch(fetch).run())) { assertEquals(1, outputs.size()); final int[][] expected = {{31}}; assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); @@ -112,7 +110,7 @@ public class SessionTest { .setOptions(fullTraceRunOptions()) .runAndFetchMetadata(); // Sanity check on outputs. - AutoCloseableList> outputs = new AutoCloseableList>(result.outputs); + TestUtil.AutoCloseableList> outputs = new TestUtil.AutoCloseableList>(result.outputs); assertEquals(1, outputs.size()); final int[][] expected = {{31}}; assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); @@ -135,8 +133,8 @@ public class SessionTest { Session s = new Session(g)) { TestUtil.constant(g, "c1", 2718); TestUtil.constant(g, "c2", 31415); - AutoCloseableList> outputs = - new AutoCloseableList>(s.runner().fetch("c2").fetch("c1").run()); + TestUtil.AutoCloseableList> outputs = + new TestUtil.AutoCloseableList>(s.runner().fetch("c2").fetch("c1").run()); assertEquals(2, outputs.size()); assertEquals(31415, outputs.get(0).intValue()); assertEquals(2718, outputs.get(1).intValue()); @@ -164,28 +162,6 @@ public class SessionTest { Session s = new Session(g, singleThreadConfigProto())) {} } - private static final class AutoCloseableList extends ArrayList - implements AutoCloseable { - AutoCloseableList(Collection c) { - super(c); - } - - @Override - public void close() { - Exception toThrow = null; - for (AutoCloseable c : this) { - try { - c.close(); - } catch (Exception e) { - toThrow = e; - } - } - if (toThrow != null) { - throw new RuntimeException(toThrow); - } - } - } - private static byte[] fullTraceRunOptions() { // Ideally this would use the generated Java sources for protocol buffers // and end up with something like the snippet below. However, generating diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java index c973b5a3d8b2be8ee21710d65732bc1e5c3b520a..4e848864167982c750b390a77a1ab7f5d0d40fe9 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java +++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java @@ -16,9 +16,34 @@ limitations under the License. package org.tensorflow; import java.lang.reflect.Array; +import java.util.ArrayList; +import java.util.Collection; /** Static utility functions. */ public class TestUtil { + + public static final class AutoCloseableList extends ArrayList + implements AutoCloseable { + AutoCloseableList(Collection c) { + super(c); + } + + @Override + public void close() { + Exception toThrow = null; + for (AutoCloseable c : this) { + try { + c.close(); + } catch (Exception e) { + toThrow = e; + } + } + if (toThrow != null) { + throw new RuntimeException(toThrow); + } + } + } + public static Output constant(Graph g, String name, Object value) { try (Tensor t = Tensor.create(value)) { return g.opBuilder("Const", name) @@ -36,7 +61,7 @@ public class TestUtil { .output(0); } - public static Output addN(Graph g, Output... inputs) { + public static Output addN(Graph g, Output... inputs) { return g.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0); } @@ -58,6 +83,13 @@ public class TestUtil { .setAttr("num_split", numSplit) .build(); } + + public static Output square(Graph g, String name, Output value) { + return g.opBuilder("Square", name) + .addInput(value) + .build() + .output(0); + } public static void transpose_A_times_X(Graph g, int[][] a) { Output aa = constant(g, "A", a); diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index a06b536f5b24334963004ef9c321d59dc1685a44..47cf4d6709d75539fda55a473c5f9b7aac97ce65 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -4,14 +4,16 @@ # Public targets: # ":platform" - Low-level and platform-specific Python code. -package(default_visibility = [ +visibility = [ "//engedu/ml/tf_from_scratch:__pkg__", "//tensorflow:internal", "//tensorflow/contrib/lite/toco/python:__pkg__", "//tensorflow_models:__subpackages__", # TODO(aselle): to pass open source test. "//bazel_pip/tensorflow/contrib/lite/toco/python:__pkg__", -]) +] + +package(default_visibility = visibility) licenses(["notice"]) # Apache 2.0 @@ -55,12 +57,12 @@ py_library( "//tensorflow/contrib/lite/toco/python:__pkg__", # TODO(b/34059704): remove when fixed "//tensorflow/python/debug:__pkg__", # TODO(b/34059704): remove when fixed "//tensorflow/python/tools:__pkg__", # TODO(b/34059704): remove when fixed - "//tensorflow/tools/api/generator:__pkg__", "//tensorflow/tools/quantization:__pkg__", # TODO(b/34059704): remove when fixed ], deps = [ ":no_contrib", "//tensorflow/contrib:contrib_py", + "//tensorflow/python/estimator:estimator_py", ], ) @@ -126,7 +128,6 @@ py_library( ":weights_broadcast_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python/data", - "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/feature_column:feature_column_py", "//tensorflow/python/keras", "//tensorflow/python/ops/distributions", @@ -278,6 +279,9 @@ cc_library( name = "ndarray_tensor_bridge", srcs = ["lib/core/ndarray_tensor_bridge.cc"], hdrs = ["lib/core/ndarray_tensor_bridge.h"], + visibility = visibility + [ + "//learning/deepmind/courier:__subpackages__", + ], deps = [ ":bfloat16_lib", ":numpy_lib", @@ -358,6 +362,9 @@ cc_library( name = "ndarray_tensor", srcs = ["lib/core/ndarray_tensor.cc"], hdrs = ["lib/core/ndarray_tensor.h"], + visibility = visibility + [ + "//learning/deepmind/courier:__subpackages__", + ], deps = [ ":bfloat16_lib", ":ndarray_tensor_bridge", @@ -690,12 +697,22 @@ py_library( ], ) +py_library( + name = "error_interpolation", + srcs = [ + "framework/error_interpolation.py", + ], + srcs_version = "PY2AND3", + deps = [], +) + py_library( name = "function", srcs = ["framework/function.py"], srcs_version = "PY2AND3", deps = [ ":array_ops", + ":cond_v2_impl", ":dtypes", ":framework_ops", ":graph_to_function_def", @@ -712,6 +729,7 @@ py_library( srcs = ["framework/graph_to_function_def.py"], srcs_version = "PY2AND3", deps = [ + ":cond_v2_impl", ":op_def_registry", "//tensorflow/core:protos_all_py", ], @@ -990,6 +1008,18 @@ py_test( ], ) +py_test( + name = "framework_error_interpolation_test", + size = "small", + srcs = ["framework/error_interpolation_test.py"], + main = "framework/error_interpolation_test.py", + srcs_version = "PY2AND3", + deps = [ + ":client_testlib", + ":error_interpolation", + ], +) + py_test( name = "framework_subscribe_test", size = "small", @@ -1052,7 +1082,6 @@ tf_gen_op_wrapper_private_py( name = "functional_ops_gen", visibility = [ "//learning/brain/python/ops:__pkg__", - "//tensorflow/contrib/control_flow:__pkg__", ], ) @@ -1600,6 +1629,9 @@ tf_gen_op_wrapper_private_py( tf_gen_op_wrapper_private_py( name = "resource_variable_ops_gen", + visibility = [ + "//tensorflow/compiler/tf2xla:internal", + ], ) tf_gen_op_wrapper_private_py( @@ -1827,6 +1859,7 @@ py_library( "tensor_shape", ":array_ops", ":array_ops_gen", + ":cond_v2_impl", ":constant_op", ":control_flow_ops_gen", ":control_flow_util", @@ -1855,6 +1888,37 @@ py_library( ], ) +py_library( + name = "cond_v2", + srcs = [ + "ops/cond_v2.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":cond_v2_impl", + ":function", + ":function_def_to_graph", + ":gradients", + ], +) + +py_library( + name = "cond_v2_impl", + srcs = [ + "ops/cond_v2_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":array_ops", + ":c_api_util", + ":framework_ops", + ":functional_ops_gen", + ":pywrap_tensorflow", + ":util", + "//tensorflow/core:protos_all_py", + ], +) + py_library( name = "ctc_ops", srcs = ["ops/ctc_ops.py"], @@ -1937,6 +2001,7 @@ py_library( ":array_grad", ":array_ops", ":bitwise_ops", + ":cond_v2_impl", ":control_flow_grad", ":control_flow_ops", ":control_flow_util", @@ -1953,6 +2018,7 @@ py_library( ":math_grad", ":math_ops", ":platform", + ":random_grad", ":resource_variable_ops", ":spectral_grad", ":util", @@ -2331,6 +2397,19 @@ py_library( ], ) +py_library( + name = "random_grad", + srcs = ["ops/random_grad.py"], + srcs_version = "PY2AND3", + deps = [ + ":array_ops", + ":dtypes", + ":framework_ops", + ":math_ops", + ":random_ops_gen", + ], +) + py_library( name = "random_ops", srcs = ["ops/random_ops.py"], @@ -2391,6 +2470,7 @@ py_library( srcs = ["ops/script_ops.py"], srcs_version = "PY2AND3", deps = [ + ":array_ops", ":framework_for_generated_wrappers", ":script_ops_gen", "//third_party/py/numpy", @@ -3338,6 +3418,19 @@ py_library( ], ) +py_test( + name = "lock_util_test", + size = "small", + srcs = ["util/lock_util_test.py"], + main = "util/lock_util_test.py", + srcs_version = "PY2AND3", + deps = [ + ":client_testlib", + ":util", + "@absl_py//absl/testing:parameterized", + ], +) + tf_proto_library( name = "protos_all", srcs = glob( @@ -3656,6 +3749,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":c_api_util", + ":error_interpolation", ":errors", ":framework", ":framework_for_generated_wrappers", @@ -3856,7 +3950,7 @@ tf_cuda_library( tf_py_test( name = "session_test", - size = "small", + size = "medium", srcs = ["client/session_test.py"], additional_deps = [ ":array_ops", @@ -4038,6 +4132,19 @@ py_test( ], ) +py_test( + name = "tf_record_test", + size = "small", + srcs = ["lib/io/tf_record_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":client_testlib", + ":errors", + ":lib", + ":util", + ], +) + cuda_py_test( name = "adam_test", size = "small", diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index cf707fb2c731c0db57c2335d3ffd49b292c811cc..a2ab63bb48799d5b93882bb87ab40b02dbb96621 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -79,7 +79,6 @@ from tensorflow.python.ops import initializers_ns as initializers # Bring in subpackages. from tensorflow.python import data from tensorflow.python import keras -from tensorflow.python.estimator import estimator_lib as estimator from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.layers import layers from tensorflow.python.ops import bitwise_ops as bitwise diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 35aa37ac6dd721750cd72b54d1b8ef6a70402038..f3b788f9319a94756b9b99c1fab190139d22a51b 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -1291,7 +1291,7 @@ class BaseSession(SessionInterface): raise type(e)(node_def, op, message) def _extend_graph(self): - with self._graph._lock: # pylint: disable=protected-access + with self._graph._session_run_lock(): # pylint: disable=protected-access tf_session.ExtendSession(self._session) # The threshold to run garbage collection to delete dead tensors. diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index e49d0671050f557842ad1d3305331d61cd8c9672..b72e029d1ccb688f5992f6cc8695969be5e5e2e3 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import collections +import random import os import sys import threading @@ -1040,40 +1041,72 @@ class SessionTest(test_util.TensorFlowTestCase): for t in threads: t.join() - def testParallelRunAndBuild(self): + @staticmethod + def _build_graph(): + time.sleep(random.random() * 0.1) + # Do some graph construction. Try to exercise non-trivial paths. + graph = ops.get_default_graph() + gdef = None + for _ in range(10): + x = array_ops.placeholder(dtype=dtypes.float32) + with ops.colocate_with(x): + y = array_ops.placeholder(dtype=dtypes.float32) + with ops.device('/cpu:0'): + z = control_flow_ops.while_loop( + lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y]) + with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}): + gradients_impl.gradients(z, [x, y]) + if gdef is None: + gdef = graph.as_graph_def() + else: + importer.import_graph_def(gdef, name='import') + + def testParallelRunAndSingleBuild(self): with session.Session() as sess: c = constant_op.constant(5.0) stop = threading.Event() def run_loop(): while not stop.is_set(): + time.sleep(random.random() * 0.1) self.assertEqual(sess.run(c), 5.0) - threads = [self.checkedThread(target=run_loop) for _ in range(100)] + threads = [self.checkedThread(target=run_loop) for _ in range(10)] for t in threads: t.start() - # Do some graph construction. Try to exercise non-trivial paths. - graph = ops.get_default_graph() - gdef = None - for _ in range(10): - x = array_ops.placeholder(dtype=dtypes.float32) - with ops.colocate_with(x): - y = array_ops.placeholder(dtype=dtypes.float32) - with ops.device('/cpu:0'): - z = control_flow_ops.while_loop( - lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y]) - with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}): - gradients_impl.gradients(z, [x, y]) - if gdef is None: - gdef = graph.as_graph_def() - else: - importer.import_graph_def(gdef, name='import') + SessionTest._build_graph() stop.set() for t in threads: t.join() + def testParallelRunAndParallelBuild(self): + with session.Session() as sess: + c = constant_op.constant(5.0) + stop = threading.Event() + + def run_loop(): + while not stop.is_set(): + time.sleep(random.random() * 0.1) + self.assertEqual(sess.run(c), 5.0) + + run_threads = [self.checkedThread(target=run_loop) for _ in range(10)] + for t in run_threads: + t.start() + + build_threads = [self.checkedThread(target=SessionTest._build_graph) + for _ in range(10)] + for t in build_threads: + t.start() + for t in build_threads: + t.join() + + # Let the run_threads run until the build threads are finished. + stop.set() + for t in run_threads: + t.join() + def testRunFeedDict(self): with session.Session() as s: x = array_ops.zeros([2]) diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 1db1432d6521bb5f48558081916158792010b1c5..985cb904360ac293461936bf67fb1b1de2c77b4a 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -135,7 +135,7 @@ tensorflow::ImportNumpy(); // Convert TF_DeviceListMemoryBytes and TF_Dim int64_t output to Python integers %typemap(out) int64_t { - $result = PyInt_FromLong($1); + $result = PyLong_FromLongLong($1); } // We use TF_OperationGetControlInputs_wrapper instead of @@ -610,7 +610,7 @@ def TF_Reset(target, containers=None, config=None): } for (size_t i = 0; i < $1.size(); ++i) { - PyList_SET_ITEM($result, i, PyInt_FromLong($1[i])); + PyList_SET_ITEM($result, i, PyLong_FromLongLong($1[i])); } } @@ -673,7 +673,7 @@ def TF_Reset(target, containers=None, config=None): } for (size_t i = 0; i < $1.size(); ++i) { - PyList_SET_ITEM($result, i, PyInt_FromLong($1[i])); + PyList_SET_ITEM($result, i, PyLong_FromLongLong($1[i])); } } diff --git a/tensorflow/python/compat/BUILD b/tensorflow/python/compat/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..5f55b228186ca97f72041ef02641750d8ac4b276 --- /dev/null +++ b/tensorflow/python/compat/BUILD @@ -0,0 +1,10 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_library( + name = "compat", + srcs = ["compat.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], +) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py new file mode 100644 index 0000000000000000000000000000000000000000..e05ad5544729ca6100ca436936add6340ca19e37 --- /dev/null +++ b/tensorflow/python/compat/compat.py @@ -0,0 +1,81 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for API compatibility between TensorFlow release versions. + +See +@{$guide/version_compat#backward_and_partial_forward_compatibility} +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import datetime + +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 8, 1) + + +def forward_compatible(year, month, day): + """Return true if the forward compatibility window has expired. + + Forward-compatibility refers to scenarios where the producer of a TensorFlow + model (a GraphDef or SavedModel) is compiled against a version of the + TensorFlow library newer than what the consumer was compiled against. The + "producer" is typically a Python program that constructs and trains a model + while the "consumer" is typically another program that loads and serves the + model. + + TensorFlow has been supporting a 3 week forward-compatibility window for + programs compiled from source at HEAD. + + For example, consider the case where a new operation `MyNewAwesomeAdd` is + created with the intent of replacing the implementation of an existing Python + wrapper - `tf.add`. The Python wrapper implementation should change from + something like: + + ```python + def add(inputs, name=None): + return gen_math_ops.add(inputs, name) + ``` + + to: + + ```python + from tensorflow.python.compat import compat + + def add(inputs, name=None): + if compat.forward_compatible(year, month, day): + # Can use the awesome new implementation. + return gen_math_ops.my_new_awesome_add(inputs, name) + # To maintain forward compatibiltiy, use the old implementation. + return gen_math_ops.add(inputs, name) + ``` + + Where `year`, `month`, and `day` specify the date beyond which binaries + that consume a model are expected to have been updated to include the + new operations. This date is typically at least 3 weeks beyond the date + the code that adds the new operation is committed. + + Args: + year: A year (e.g., 2018). + month: A month (1 <= month <= 12) in year. + day: A day (1 <= day <= 31, or 30, or 29, or 28) in month. + + Returns: + True if the caller can expect that serialized TensorFlow graphs produced + can be consumed by programs that are compiled with the TensorFlow library + source code after (year, month, day). + """ + return _FORWARD_COMPATIBILITY_HORIZON > datetime.date(year, month, day) diff --git a/tensorflow/python/data/__init__.py b/tensorflow/python/data/__init__.py index 7efe0948e7729c398f972977b51426d80b8cd83e..3b9bf2469e6d41fd0e8c5199af677e60bedf93f9 100644 --- a/tensorflow/python/data/__init__.py +++ b/tensorflow/python/data/__init__.py @@ -14,7 +14,7 @@ # ============================================================================== """`tf.data.Dataset` API for input pipelines. -See the @{$datasets$Importing Data} Programmer's Guide for an overview. +See @{$guide/datasets$Importing Data} for an overview. """ from __future__ import absolute_import diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py index 50bb0837b7052d67ced4fdf5c9c7e96212bdb415..c3d42b49afc0b5674950e4fe8f3048e9ad389796 100644 --- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py @@ -18,9 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import time + from absl.testing import parameterized import numpy as np +from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -461,5 +464,55 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase): 5, padded_shapes=shape_as_tensor) +class BatchDatasetBenchmark(test.Benchmark): + + def benchmarkBatchSparse(self): + non_zeros_per_row_values = [0, 1, 5, 10, 100] + batch_size_values = [1, 32, 64, 128, 1024] + + sparse_placeholder = array_ops.sparse_placeholder(dtype=dtypes.int64) + batch_size_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=[]) + + dataset = dataset_ops.Dataset.from_tensors(sparse_placeholder).repeat( + ).batch(batch_size_placeholder) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + for non_zeros_per_row in non_zeros_per_row_values: + + sparse_value = sparse_tensor.SparseTensorValue( + indices=np.arange(non_zeros_per_row, dtype=np.int64)[:, np.newaxis], + values=np.arange(non_zeros_per_row, dtype=np.int64), + dense_shape=[1000]) + + for batch_size in batch_size_values: + + with session.Session() as sess: + sess.run(iterator.initializer, feed_dict={ + sparse_placeholder: sparse_value, + batch_size_placeholder: batch_size}) + # Run five steps to warm up the session caches before taking the + # first measurement. + for _ in range(5): + sess.run(next_element.indices.op) + deltas = [] + for _ in range(100): + start = time.time() + for _ in range(100): + sess.run(next_element.indices.op) + end = time.time() + deltas.append(end - start) + + median_wall_time = np.median(deltas) / 100.0 + + print('Batch sparse dataset non-zeros per row: %d batch_size: %d ' + 'wall time: %f' + % (non_zeros_per_row, batch_size, median_wall_time)) + self.report_benchmark( + iters=10000, wall_time=median_wall_time, + name='benchmark_batch_sparse_dataset_nnz_%d_batch_size_%d' % ( + non_zeros_per_row, batch_size)) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py index 768d4ac82cce4e78e9ff493cb5b4401614ecd1c0..0ecd821e9e473522b0cf4bd7bbceb071ecf5bb9e 100644 --- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py @@ -659,6 +659,13 @@ class MapDatasetTest(test.TestCase): break self.assertTrue(found_warning) + def testNestedDatasetError(self): + dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0]) + with self.assertRaisesRegexp( + NotImplementedError, r"The Dataset.map\(\) transformation does not " + "currently support nested datasets as outputs."): + _ = dataset.map(dataset_ops.Dataset.from_tensor_slices) + class MapDatasetBenchmark(test.Benchmark): diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 9e7af878d3d13e80a87b98257e27e4b9f9aab939..7cb6627615461efec074d9ae02ce7dd4c57f86b9 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -218,7 +218,7 @@ class Dataset(object): @{tf.constant} operations. For large datasets (> 1 GB), this can waste memory and run into byte limits of graph serialization. If tensors contains one or more large NumPy arrays, consider the alternative described in - @{$programmers_guide/datasets#consuming_numpy_arrays$this guide}. + @{$guide/datasets#consuming_numpy_arrays$this guide}. Args: tensors: A nested structure of tensors. @@ -237,7 +237,7 @@ class Dataset(object): @{tf.constant} operations. For large datasets (> 1 GB), this can waste memory and run into byte limits of graph serialization. If tensors contains one or more large NumPy arrays, consider the alternative described in - @{$programmers_guide/datasets#consuming_numpy_arrays$this guide}. + @{$guide/datasets#consuming_numpy_arrays$this guide}. Args: tensors: A nested structure of tensors, each having the same size in the @@ -809,11 +809,12 @@ class Dataset(object): def batch(self, batch_size, drop_remainder=False): """Combines consecutive elements of this dataset into batches. - NOTE: If the number of elements (`N`) in this dataset is not an exact - multiple of `batch_size`, the final batch contain smaller tensors with - shape `N % batch_size` in the batch dimension. If your program depends on - the batches having the same shape, consider using the - @{tf.contrib.data.batch_and_drop_remainder} transformation instead. + The tensors in the resulting element will have an additional outer + dimension, which will be `batch_size` (or `N % batch_size` for the last + element if `batch_size` does not divide the number of input elements `N` + evenly and `drop_remainder` is `False`). If your program depends on the + batches having the same outer dimension, you should set the `drop_remainder` + argument to `True` to prevent the smaller batch from being produced. Args: batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of @@ -836,13 +837,19 @@ class Dataset(object): """Combines consecutive elements of this dataset into padded batches. This transformation combines multiple consecutive elements of the input - dataset into a single element. Like @{tf.data.Dataset.batch}, the tensors - in the resulting element have an additional outer dimension, which will be - `batch_size` for all but the last element, and `N % batch_size` for the - last element (where `N` is the number of elements in this dataset). Unlike - @{tf.data.Dataset.batch}, the elements may have different shapes for some - of their components, and this transformation will pad each component to - the respective shape in `padding_shapes`. The `padding_shapes` argument + dataset into a single element. + + Like @{tf.data.Dataset.batch}, the tensors in the resulting element will + have an additional outer dimension, which will be `batch_size` (or + `N % batch_size` for the last element if `batch_size` does not divide the + number of input elements `N` evenly and `drop_remainder` is `False`). If + your program depends on the batches having the same outer dimension, you + should set the `drop_remainder` argument to `True` to prevent the smaller + batch from being produced. + + Unlike @{tf.data.Dataset.batch}, the input elements to be batched may have + different shapes, and this transformation will pad each component to the + respective shape in `padding_shapes`. The `padding_shapes` argument determines the resulting shape for each dimension of each component in an output element: @@ -852,12 +859,6 @@ class Dataset(object): will be padded out to the maximum length of all elements in that dimension. - NOTE: If the number of elements (`N`) in this dataset is not an exact - multiple of `batch_size`, the final batch contain smaller tensors with - shape `N % batch_size` in the batch dimension. If your program depends on - the batches having the same shape, consider using the - @{tf.contrib.data.padded_batch_and_drop_remainder} transformation instead. - See also @{tf.contrib.data.dense_to_sparse_batch}, which combines elements that may have different shapes into a @{tf.SparseTensor}. @@ -1148,13 +1149,74 @@ class SparseTensorSliceDataset(Dataset): return (dtypes.int64, self._sparse_tensor.dtype, dtypes.int64) +class _NestedDatasetComponent(object): + """The structure of a `Dataset` nested in a component of another `Dataset`. + + A `StructuredFunctionWrapper` around a function that returns a `Dataset` as + one of its components will have a `NestedDatasetComponent` in the + corresponding position in the `output_classes`, `output_shapes`, and + `output_types` properties. + + NOTE(mrry): This class is not currently exposed via the public API. Support + for nested datasets can be enabled on a function-by-function basis by setting + `experimental_nested_dataset_support=True` in the `StructuredFunctionWrapper` + initializer. + + TODO(b/110122868): Add this class, or something equivalent, to the public API. + We are considering revising the public API for accessing Dataset structure + (`output_classes` etc.) based on experience with nested datasets and other + custom component types. + """ + + def __init__(self, dataset): + self._output_classes = dataset.output_classes + self._output_shapes = dataset.output_shapes + self._output_types = dataset.output_types + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + +class _VariantDataset(Dataset): + """A Dataset wrapper around a @{tf.variant}-typed function argument.""" + + def __init__(self, dataset_variant, structure): + super(_VariantDataset, self).__init__() + self._dataset_variant = dataset_variant + self._structure = structure + + def _as_variant_tensor(self): + return self._dataset_variant + + @property + def output_classes(self): + return self._structure.output_classes + + @property + def output_shapes(self): + return self._structure.output_shapes + + @property + def output_types(self): + return self._structure.output_types + + class StructuredFunctionWrapper(object): """A wrapper for `Defun` that supports structured arguments and return values. """ def __init__(self, func, transformation_name, dataset=None, input_classes=None, input_shapes=None, input_types=None, - add_to_graph=True): + add_to_graph=True, experimental_nested_dataset_support=False): """Creates a new `StructuredFunctionWrapper` for the given function. Args: @@ -1173,6 +1235,8 @@ class StructuredFunctionWrapper(object): argument defines the element types and structure for `func` arguments. add_to_graph: (Optional.) If `True`, the function will be added to the default graph. + experimental_nested_dataset_support: (Optional.) If `True`, the function + will support @{tf.data.Dataset} objects as arguments and return values. Raises: ValueError: If an invalid combination of `dataset`, `input_classes`, @@ -1194,14 +1258,37 @@ class StructuredFunctionWrapper(object): self._input_types = dataset.output_types self._input_classes = dataset.output_classes - @function.Defun(*defun_args( - input_types=self._input_types, input_classes=self._input_classes)) + self._transformation_name = transformation_name + + # TODO(b/110122868): Enable this support for all `tf.data` functions. + self._nested_dataset_support = experimental_nested_dataset_support + + @function.Defun(*self._defun_args()) def tf_data_structured_function_wrapper(*args): """Wrapper for passing nested structures to and from tf.data functions.""" - nested_args = restructure_args(args, - input_shapes=self._input_shapes, - input_types=self._input_types, - input_classes=self._input_classes) + flat_args = [] + for arg, arg_class, arg_shape, arg_type in zip( + args, + nest.flatten(self._input_classes), + nest.flatten(self._input_shapes), + nest.flatten(self._input_types)): + # TODO(b/110122868): Add a registration mechanism for new component + # types. + if arg_class is sparse_tensor_lib.SparseTensor: + arg = sparse.deserialize_sparse_tensors( + arg, arg_type, arg_shape, arg_class) + arg.indices.set_shape([None, arg_shape.ndims]) + arg.dense_shape.set_shape([arg_shape.ndims]) + elif isinstance(arg_class, _NestedDatasetComponent): + assert self._nested_dataset_support + arg = _VariantDataset(arg, arg_class) + else: + arg.set_shape(arg_shape) + flat_args.append(arg) + nested_args = nest.pack_sequence_as(self._input_classes, flat_args) + if not _should_unpack_args(nested_args): + nested_args = (nested_args,) + ret = func(*nested_args) # If `func` returns a list of tensors, `nest.flatten()` and # `ops.convert_to_tensor()` would conspire to attempt to stack @@ -1218,24 +1305,45 @@ class StructuredFunctionWrapper(object): # Convert any `SparseTensorValue`s to `SparseTensor`s and all other # values to tensors. - ret = nest.pack_sequence_as(ret, [ - sparse_tensor_lib.SparseTensor.from_value(t) - if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) + flat_ret = [] + flat_classes = [] + flat_shapes = [] + flat_types = [] + for t in nest.flatten(ret): + # TODO(b/110122868): Add a registration mechanism for new component + # types. + if sparse_tensor_lib.is_sparse(t): + t = sparse_tensor_lib.SparseTensor.from_value(t) + flat_ret.append(sparse.serialize_sparse_tensors(t)) + flat_classes.append(sparse_tensor_lib.SparseTensor) + flat_shapes.append(t.get_shape()) + flat_types.append(t.dtype) + elif isinstance(t, Dataset): + if not self._nested_dataset_support: + raise NotImplementedError( + "The %s transformation does not currently support nested " + "datasets as outputs." % self._transformation_name) + + flat_ret.append(t._as_variant_tensor()) # pylint: disable=protected-access + component = _NestedDatasetComponent(t) + flat_classes.append(component) + flat_shapes.append(component) + flat_types.append(component) + else: + t = ops.convert_to_tensor(t) + flat_ret.append(t) + flat_classes.append(ops.Tensor) + flat_shapes.append(t.get_shape()) + flat_types.append(t.dtype) - self._output_classes = sparse.get_classes(ret) - self._output_shapes = nest.pack_sequence_as( - ret, [t.get_shape() for t in nest.flatten(ret)]) - self._output_types = nest.pack_sequence_as( - ret, [t.dtype for t in nest.flatten(ret)]) + ret = nest.pack_sequence_as(ret, flat_ret) + self._output_classes = nest.pack_sequence_as(ret, flat_classes) + self._output_shapes = nest.pack_sequence_as(ret, flat_shapes) + self._output_types = nest.pack_sequence_as(ret, flat_types) _warn_if_collections(transformation_name) - # Serialize any sparse tensors. - ret = nest.pack_sequence_as( - ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) - return nest.flatten(ret) + return flat_ret self._function = tf_data_structured_function_wrapper if add_to_graph: @@ -1246,6 +1354,25 @@ class StructuredFunctionWrapper(object): # in case (e.g.) we need to rerun the function. self._function._create_definition_if_needed() # pylint: disable=protected-access + def _defun_args(self): + """Returns a flat list of @{tf.DType} for the input element structure.""" + ret = [] + for input_type, input_class in zip(nest.flatten(self._input_types), + nest.flatten(self._input_classes)): + # TODO(b/110122868): Add a registration mechanism for new component types. + if input_class is sparse_tensor_lib.SparseTensor: + ret.append(dtypes.variant) + elif isinstance(input_class, _NestedDatasetComponent): + if not self._nested_dataset_support: + raise NotImplementedError( + "The %s transformation does not currently support nested " + "datasets as inputs." % self._transformation_name) + ret.append(dtypes.variant) + else: + assert isinstance(input_type, dtypes.DType) + ret.append(input_type) + return ret + @property def output_classes(self): return self._output_classes @@ -1287,109 +1414,6 @@ def flat_structure(dataset): } -# TODO(mrry): Investigate adding a `Defun` wrapper that combines -# `defun_args()`, `restructure_args()`, and a future helper that consumes the -# outputs of the wrapped function. -def defun_args(dataset=None, input_types=None, input_classes=None): - """Returns a flat list of @{tf.DType} for a given element structure. - - The expected usage for an example function is as follows: - - ```python - input_dataset = ... # A `tf.data.Dataset`. - - @function.Defun(*defun_args(input_dataset)) - def tf_example_func(*args): - nested_args = restructure_args(args, input_dataset) - # [Destructure and handle the return values from `example_func()`. - ``` - - Either `dataset`, or both of `input_types` and `input_classes` must be - specified. If `dataset` is not specified, the structures of `input_types` and - `input_classes` must be compatible. - - Args: - dataset: (Optional.) A @{tf.data.Dataset} whose element structure should - be flattened. - input_types: (Optional.) A nested structure of @{tf.DType} with the desired - structure and types for each argument. - input_classes: (Optional.) A nested structure of `type` with the desired - structure and classes for each argument. - - Returns: - A flat list of @{tf.DType} for the given element structure. - """ - if input_types is None: - assert dataset is not None - assert input_classes is None - input_types = dataset.output_types - input_classes = dataset.output_classes - else: - assert input_types is not None and input_classes is not None - return nest.flatten( - sparse.as_dense_types(input_types, input_classes)) - - -def restructure_args(args, dataset=None, input_shapes=None, input_types=None, - input_classes=None): - """Converts a flat tuple of arguments into a given structure. - - The intended use is to bridge between the flat tuple of unshaped @{tf.Tensor} - arguments that a `Defun` receives and the potentially nested structures that - `tf.data` functions expect. - - The expected usage for an example function is as follows: - - ```python - input_dataset = ... # A `tf.data.Dataset`. - - @function.Defun(*defun_args(input_dataset)) - def tf_example_func(*args): - nested_args = restructure_args(args, input_dataset) - ret = example_func(*nested_args) - # [Destructure and handle the return values from `example_func()`. - ``` - - Either `dataset`, or all of `input_shapes`, `input_types` and `input_classes` - must be specified. If `dataset` is not specified, the structures of - `input_shapes`, `input_types` and `input_classes` must be compatible. - - Args: - args: A flat tuple of @{tf.Tensor} objects, representing the arguments - to a TensorFlow function. - dataset: (Optional.) A @{tf.data.Dataset} whose element structure matches - the desired structure of the arguments. - input_shapes: (Optional.) A nested structure of @{tf.TensorShape} with the - desired structure and static shapes for each argument. - input_types: (Optional.) A nested structure of @{tf.DType} with the desired - structure and types for each argument. - input_classes: (Optional.) A nested structure of `type` with the desired - structure and classes for each argument. - - Returns: - A nested structure representing the arguments. - """ - if input_shapes is None: - assert dataset is not None - assert input_types is None and input_classes is None - input_shapes = dataset.output_shapes - input_types = dataset.output_types - input_classes = dataset.output_classes - else: - assert input_types is not None and input_classes is not None - - dense_shapes = sparse.as_dense_shapes(input_shapes, input_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(input_classes, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, input_types, input_shapes, input_classes) - if not _should_unpack_args(nested_args): - nested_args = (nested_args,) - return nested_args - - class _GeneratorDataset(Dataset): """A `Dataset` that generates elements by invoking a function.""" @@ -2105,19 +2129,14 @@ class FlatMapDataset(Dataset): super(FlatMapDataset, self).__init__() self._input_dataset = input_dataset - # TODO(b/110122868): When we handle nested datasets natively as the return - # value from `map_func`, we can avoid needing this wrapper. - def map_func_wrapper(*args): - dataset = map_func(*args) - if not isinstance(dataset, Dataset): - raise TypeError("`map_func` must return a `Dataset` object.") - self._output_classes = dataset.output_classes - self._output_shapes = dataset.output_shapes - self._output_types = dataset.output_types - return dataset._as_variant_tensor() # pylint: disable=protected-access - wrapped_func = StructuredFunctionWrapper( - map_func_wrapper, self._transformation_name(), input_dataset) + map_func, self._transformation_name(), input_dataset, + experimental_nested_dataset_support=True) + if not isinstance(wrapped_func.output_classes, _NestedDatasetComponent): + raise TypeError("`map_func` must return a `Dataset` object.") + self._output_classes = wrapped_func.output_classes.output_classes + self._output_types = wrapped_func.output_types.output_types + self._output_shapes = wrapped_func.output_shapes.output_shapes self._map_func = wrapped_func.function def _as_variant_tensor(self): diff --git a/tensorflow/python/data/util/random_seed_test.py b/tensorflow/python/data/util/random_seed_test.py index 33227e82afe6fe1c748693d107d4e9844abb8e09..a809151e6ef57de8a39806b8164f818d94b8a783 100644 --- a/tensorflow/python/data/util/random_seed_test.py +++ b/tensorflow/python/data/util/random_seed_test.py @@ -30,7 +30,7 @@ from tensorflow.python.platform import test class RandomSeedTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRandomSeed(self): zero_t = constant_op.constant(0, dtype=dtypes.int64, name='zero') one_t = constant_op.constant(1, dtype=dtypes.int64, name='one') diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 09062abd7446628ede12e782e202ee0e55905879..c025dc8aa58a500ace3e28ba4528abd4f4c38ba7 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -5,7 +5,7 @@ # # ":debug_py": Public Python methods and classes of tfdbg. # For API documentation, see https://www.tensorflow.org/api_docs/python/tfdbg -# For a user interface walkthrough, see https://www.tensorflow.org/programmers_guide/debugger +# For a user interface walkthrough, see https://www.tensorflow.org/guide/debugger # ":grpc_debug_server": Server interface for grpc:// debug URLs. package( @@ -167,6 +167,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:platform", + "//third_party/py/numpy", "@six_archive//:six", ], ) @@ -453,6 +454,17 @@ py_binary( ], ) +py_binary( + name = "debug_keras", + srcs = ["examples/debug_keras.py"], + srcs_version = "PY2AND3", + deps = [ + ":debug_py", + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + ], +) + py_test( name = "common_test", size = "small", @@ -802,6 +814,7 @@ py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform", "//tensorflow/python:platform_test", + "//third_party/py/numpy", ], ) @@ -1084,6 +1097,7 @@ py_test( "//tensorflow/python:state_ops", "//tensorflow/python:training", "//tensorflow/python:variables", + "//third_party/py/numpy", ], ) @@ -1094,6 +1108,7 @@ sh_test( data = [ ":debug_errors", ":debug_fibonacci", + ":debug_keras", ":debug_mnist", ":debug_tflearn_iris", ":offline_analyzer", diff --git a/tensorflow/python/debug/README.md b/tensorflow/python/debug/README.md index 269bbb19bdb898d1d81d0b9c618a284a437e68b9..9c16af4d79754cee5d77158d5c2466412c6b9e68 100644 --- a/tensorflow/python/debug/README.md +++ b/tensorflow/python/debug/README.md @@ -28,7 +28,7 @@ models: * Easy access through session wrappers * Easy integration with common high-level APIs, such as - [TensorFlow Estimators](https://www.tensorflow.org/programmers_guide/estimators) and + [TensorFlow Estimators](https://www.tensorflow.org/guide/estimators) and [Keras](https://keras.io/) * Inspection of runtime tensor values and node connections * Conditional breaking after runs that generate tensors satisfying given @@ -43,7 +43,7 @@ models: ## How to use TFDBG? -* For a walkthrough of TFDBG command-line interface, see https://www.tensorflow.org/programmers_guide/debugger. +* For a walkthrough of TFDBG command-line interface, see https://www.tensorflow.org/guide/debugger. * For information on the web GUI of TFDBG (TensorBoard Debugger Plugin), see [this README](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/debugger/README.md). * For programmatic use of the API of TFDBG, see https://www.tensorflow.org/api_docs/python/tfdbg. diff --git a/tensorflow/python/debug/cli/debugger_cli_common.py b/tensorflow/python/debug/cli/debugger_cli_common.py index 12e79ab07a4655c7d41f41d2e71906273e154a08..02563fde845e7951046a8bcd65899ef5e1fcc35f 100644 --- a/tensorflow/python/debug/cli/debugger_cli_common.py +++ b/tensorflow/python/debug/cli/debugger_cli_common.py @@ -23,9 +23,11 @@ import re import sre_constants import traceback +import numpy as np import six from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python import pywrap_tensorflow_internal from tensorflow.python.platform import gfile HELP_INDENT = " " @@ -131,6 +133,25 @@ def rich_text_lines_from_rich_line_list(rich_text_list, annotations=None): return RichTextLines(lines, font_attr_segs, annotations=annotations) +def get_tensorflow_version_lines(include_dependency_versions=False): + """Generate RichTextLines with TensorFlow version info. + + Args: + include_dependency_versions: Include the version of TensorFlow's key + dependencies, such as numpy. + + Returns: + A formatted, multi-line `RichTextLines` object. + """ + lines = ["TensorFlow version: %s" % pywrap_tensorflow_internal.__version__] + lines.append("") + if include_dependency_versions: + lines.append("Dependency version(s):") + lines.append(" numpy: %s" % np.__version__) + lines.append("") + return RichTextLines(lines) + + class RichTextLines(object): """Rich multi-line text. @@ -538,6 +559,8 @@ class CommandHandlerRegistry(object): HELP_COMMAND = "help" HELP_COMMAND_ALIASES = ["h"] + VERSION_COMMAND = "version" + VERSION_COMMAND_ALIASES = ["ver"] def __init__(self): # A dictionary from command prefix to handler. @@ -562,6 +585,13 @@ class CommandHandlerRegistry(object): "Print this help message.", prefix_aliases=self.HELP_COMMAND_ALIASES) + # Register a default handler for the command "version". + self.register_command_handler( + self.VERSION_COMMAND, + self._version_handler, + "Print the versions of TensorFlow and its key dependencies.", + prefix_aliases=self.VERSION_COMMAND_ALIASES) + def register_command_handler(self, prefix, handler, @@ -763,6 +793,11 @@ class CommandHandlerRegistry(object): else: return RichTextLines(["ERROR: help takes only 0 or 1 input argument."]) + def _version_handler(self, args, screen_info=None): + del args # Unused currently. + del screen_info # Unused currently. + return get_tensorflow_version_lines(include_dependency_versions=True) + def _resolve_prefix(self, token): """Resolve command prefix from the prefix itself or its alias. diff --git a/tensorflow/python/debug/cli/debugger_cli_common_test.py b/tensorflow/python/debug/cli/debugger_cli_common_test.py index 1b7a5962fe7dc4e19446c3e3b0aeab672eb30f1f..aba95e5820b1d8c6b3811fc69328317ce2c3ac64 100644 --- a/tensorflow/python/debug/cli/debugger_cli_common_test.py +++ b/tensorflow/python/debug/cli/debugger_cli_common_test.py @@ -21,6 +21,9 @@ import os import stat import tempfile +import numpy as np + +from tensorflow.python import pywrap_tensorflow_internal from tensorflow.python.debug.cli import debugger_cli_common from tensorflow.python.framework import test_util from tensorflow.python.platform import gfile @@ -547,7 +550,10 @@ class CommandHandlerRegistryTest(test_util.TensorFlowTestCase): " Show screen width in number of columns.", "", "", "help", " Aliases: h", "", " Print this help message.", "", "", "noop", " Aliases: n, NOOP", "", - " No operation.", " I.e., do nothing.", "", ""], + " No operation.", " I.e., do nothing.", "", "", + "version", " Aliases: ver", "", + " Print the versions of TensorFlow and its key " + "dependencies.", "", ""], output.lines) # Get help for one specific command prefix. @@ -575,7 +581,9 @@ class CommandHandlerRegistryTest(test_util.TensorFlowTestCase): self.assertEqual(help_intro.lines + [ "help", " Aliases: h", "", " Print this help message.", "", "", "noop", " Aliases: n, NOOP", "", " No operation.", - " I.e., do nothing.", "", "" + " I.e., do nothing.", "", "", + "version", " Aliases: ver", "", + " Print the versions of TensorFlow and its key dependencies.", "", "" ], output.lines) @@ -1147,5 +1155,22 @@ class MenuTest(test_util.TensorFlowTestCase): self.assertEqual((40, 50, ["bold"]), output.font_attr_segs[0][2]) +class GetTensorFlowVersionLinesTest(test_util.TensorFlowTestCase): + + def testGetVersionWithoutDependencies(self): + out = debugger_cli_common.get_tensorflow_version_lines() + self.assertEqual(2, len(out.lines)) + self.assertEqual( + "TensorFlow version: %s" % pywrap_tensorflow_internal.__version__, + out.lines[0]) + + def testGetVersionWithDependencies(self): + out = debugger_cli_common.get_tensorflow_version_lines(True) + self.assertIn( + "TensorFlow version: %s" % pywrap_tensorflow_internal.__version__, + out.lines) + self.assertIn(" numpy: %s" % np.__version__, out.lines) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/debug/examples/README.md b/tensorflow/python/debug/examples/README.md index cb4d484092fe39698de1ff11e4d50d4879960e0c..3b431e04dc3565037dc018991bea68ab019e8af0 100644 --- a/tensorflow/python/debug/examples/README.md +++ b/tensorflow/python/debug/examples/README.md @@ -3,7 +3,7 @@ Hi, there! The documentation of **TensorFlow Debugger (tfdbg)** has moved. See the source version at -[this new location](../../../docs_src/programmers_guide/debugger.md). +[this new location](../../../docs_src/guide/debugger.md). See the public website version at -[https://www.tensorflow.org/programmers_guide/debugger](https://www.tensorflow.org/programmers_guide/debugger). +[https://www.tensorflow.org/guide/debugger](https://www.tensorflow.org/guide/debugger). diff --git a/tensorflow/python/debug/examples/debug_keras.py b/tensorflow/python/debug/examples/debug_keras.py new file mode 100644 index 0000000000000000000000000000000000000000..3272d85ade957b254b2c1a0977156179cd71bb9d --- /dev/null +++ b/tensorflow/python/debug/examples/debug_keras.py @@ -0,0 +1,89 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""tfdbg example: debugging tf.keras models training on tf.data.Dataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +import numpy as np +import tensorflow as tf + +from tensorflow.python import debug as tf_debug + + +def main(_): + # Create a dummy dataset. + num_examples = 8 + steps_per_epoch = 2 + input_dims = 3 + output_dims = 1 + xs = np.zeros([num_examples, input_dims]) + ys = np.zeros([num_examples, output_dims]) + dataset = tf.data.Dataset.from_tensor_slices( + (xs, ys)).repeat(num_examples).batch(int(num_examples / steps_per_epoch)) + + sess = tf.Session() + if FLAGS.debug: + # Use the command-line interface (CLI) of tfdbg. + sess = tf_debug.LocalCLIDebugWrapperSession(sess, ui_type=FLAGS.ui_type) + elif FLAGS.tensorboard_debug_address: + # Use the TensorBoard Debugger Plugin (GUI of tfdbg). + sess = tf_debug.TensorBoardDebugWrapperSession( + sess, FLAGS.tensorboard_debug_address) + tf.keras.backend.set_session(sess) + + # Create a dummy model. + model = tf.keras.Sequential([ + tf.keras.layers.Dense(1, input_shape=[input_dims])]) + model.compile(loss="mse", optimizer="sgd") + + # Train the model using the dummy dataset created above. + model.fit(dataset, epochs=FLAGS.epochs, steps_per_epoch=steps_per_epoch) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.register("type", "bool", lambda v: v.lower() == "true") + parser.add_argument( + "--debug", + type="bool", + nargs="?", + const=True, + default=False, + help="Use debugger to track down bad values during training. " + "Mutually exclusive with the --tensorboard_debug_address flag.") + parser.add_argument( + "--ui_type", + type=str, + default="curses", + help="Command-line user interface type (curses | readline).") + parser.add_argument( + "--tensorboard_debug_address", + type=str, + default=None, + help="Connect to the TensorBoard Debugger Plugin backend specified by " + "the gRPC address (e.g., localhost:1234). Mutually exclusive with the " + "--debug flag.") + parser.add_argument( + "--epochs", + type=int, + default=2, + help="Number of epochs to train the model for.") + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/debug/examples/examples_test.sh b/tensorflow/python/debug/examples/examples_test.sh index e9c45a7e6e92d069f51648647620f7a7c3a5aadc..2d35b2d8bb10d17decfa404afd5004d3409c06e5 100755 --- a/tensorflow/python/debug/examples/examples_test.sh +++ b/tensorflow/python/debug/examples/examples_test.sh @@ -48,12 +48,14 @@ if [[ -z "${PYTHON_BIN_PATH}" ]]; then DEBUG_ERRORS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_errors" DEBUG_MNIST_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_mnist" DEBUG_TFLEARN_IRIS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_tflearn_iris" + DEBUG_KERAS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_keras" OFFLINE_ANALYZER_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/offline_analyzer" else DEBUG_FIBONACCI_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_fibonacci" DEBUG_ERRORS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_errors" DEBUG_MNIST_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_mnist" DEBUG_TFLEARN_IRIS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_tflearn_iris" + DEBUG_KERAS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_keras" OFFLINE_ANALYZER_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.cli.offline_analyzer" fi @@ -96,6 +98,11 @@ if [[ -d "${CUSTOM_DUMP_ROOT}" ]]; then exit 1 fi +# Test debugging of tf.keras. +cat << EOF | "${DEBUG_KERAS_BIN}" --debug --ui_type=readline +run -f has_inf_or_nan +EOF + # Test offline_analyzer. echo echo "Testing offline_analyzer" diff --git a/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py b/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py index bd00f738610627a4b3bc7c61476164188a7b460c..676097fde95e2e5a685e8e43f8f38d3e62e7084a 100644 --- a/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py +++ b/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py @@ -44,7 +44,8 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase): def _no_rewrite_session_config(self): rewriter_config = rewriter_config_pb2.RewriterConfig( - dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF) + dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF, + min_graph_nodes=-1) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) return config_pb2.ConfigProto(graph_options=graph_options) diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py index c530204bbf6959f56a72c6e67add91f1e575f067..b9524ce649c7d6d888affacc22cfadd41dbe2e40 100644 --- a/tensorflow/python/debug/wrappers/framework.py +++ b/tensorflow/python/debug/wrappers/framework.py @@ -392,6 +392,9 @@ class BaseDebugWrapperSession(session.SessionInterface): self._default_session_context_manager = None + # A cache for callables created from CallableOptions. + self._cached_callables_from_options = dict() + @property def graph(self): return self._sess.graph @@ -414,7 +417,8 @@ class BaseDebugWrapperSession(session.SessionInterface): options=None, run_metadata=None, callable_runner=None, - callable_runner_args=None): + callable_runner_args=None, + callable_options=None): """Wrapper around Session.run() that inserts tensor watch options. Args: @@ -424,7 +428,12 @@ class BaseDebugWrapperSession(session.SessionInterface): run_metadata: Same as the `run_metadata` arg to regular `Session.run()`. callable_runner: A `callable` returned by `Session.make_callable()`. If not `None`, `fetches` and `feed_dict` must both be `None`. - callable_runner_args: An optional list of arguments to `callable_runner`. + Mutually exclusive with `callable_options`. + callable_runner_args: An optional list of arguments to `callable_runner` + or for `callable_options`. + callable_options: An instance of `config_pb2.CallableOptions`, to be + used with `Session._make_callable_from_options()`. Mutually exclusive + with `callable_runner`. Returns: Simply forwards the output of the wrapped `Session.run()` call. @@ -433,13 +442,17 @@ class BaseDebugWrapperSession(session.SessionInterface): ValueError: On invalid `OnRunStartAction` value. Or if `callable_runner` is not `None` and either or both of `fetches` and `feed_dict` is `None`. """ - if not callable_runner: + if callable_runner and callable_options: + raise ValueError( + "callable_runner and callable_options are mutually exclusive, but " + "are both specified in this call to BaseDebugWrapperSession.run().") + + if not (callable_runner or callable_options): self.increment_run_call_count() - else: - if fetches or feed_dict: - raise ValueError( - "callable_runner and fetches/feed_dict are mutually exclusive, but " - "are used simultaneously.") + elif callable_runner and (fetches or feed_dict): + raise ValueError( + "callable_runner and fetches/feed_dict are mutually exclusive, " + "but are used simultaneously.") empty_fetches = not nest.flatten(fetches) if empty_fetches: @@ -449,6 +462,11 @@ class BaseDebugWrapperSession(session.SessionInterface): if self._is_disabled_thread() or empty_fetches: if callable_runner: return callable_runner(*callable_runner_args) + elif callable_options: + # pylint:disable=protected-access + return self._sess._make_callable_from_options( + callable_options)(*callable_runner_args) + # pylint:enable=protected-access else: return self._sess.run(fetches, feed_dict=feed_dict, @@ -464,19 +482,30 @@ class BaseDebugWrapperSession(session.SessionInterface): if run_start_resp.action == OnRunStartAction.DEBUG_RUN: # Decorate RunOption to fill in debugger tensor watch specifications. - decorated_run_options = options or config_pb2.RunOptions() + decorated_run_options = None + if callable_options: + callable_options_id = id(callable_options) + if callable_options_id not in self._cached_callables_from_options: + # Make a copy of callable_options to avoid mutating it. + new_callable_options = config_pb2.CallableOptions() + new_callable_options.CopyFrom(callable_options) + decorated_run_options = new_callable_options.run_options + else: + decorated_run_options = options or config_pb2.RunOptions() + run_metadata = run_metadata or config_pb2.RunMetadata() - self._decorate_run_options_for_debug( - decorated_run_options, - run_start_resp.debug_urls, - debug_ops=run_start_resp.debug_ops, - node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist, - op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist, - tensor_dtype_regex_whitelist=( - run_start_resp.tensor_dtype_regex_whitelist), - tolerate_debug_op_creation_failures=( - run_start_resp.tolerate_debug_op_creation_failures)) + if decorated_run_options: + self._decorate_run_options_for_debug( + decorated_run_options, + run_start_resp.debug_urls, + debug_ops=run_start_resp.debug_ops, + node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist, + op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist, + tensor_dtype_regex_whitelist=( + run_start_resp.tensor_dtype_regex_whitelist), + tolerate_debug_op_creation_failures=( + run_start_resp.tolerate_debug_op_creation_failures)) # Invoke the run() method of the wrapped Session. Catch any TensorFlow # runtime errors. @@ -486,6 +515,19 @@ class BaseDebugWrapperSession(session.SessionInterface): retvals = callable_runner(*callable_runner_args, options=decorated_run_options, run_metadata=run_metadata) + elif callable_options: + # pylint:disable=protected-access + if callable_options_id in self._cached_callables_from_options: + callable_object = self._cached_callables_from_options[ + callable_options_id] + else: + callable_object = self._sess._make_callable_from_options( + new_callable_options) + self._cached_callables_from_options[ + callable_options_id] = callable_object + # pylint:enable=protected-access + retvals = callable_object( + *callable_runner_args, run_metadata=run_metadata) else: retvals = self._sess.run(fetches, feed_dict=feed_dict, @@ -590,7 +632,14 @@ class BaseDebugWrapperSession(session.SessionInterface): run_metadata=kwargs.get("run_metadata", None), callable_runner=runner, callable_runner_args=runner_args) + return wrapped_runner + def _make_callable_from_options(self, callable_options): + def wrapped_runner(*feed_values, **kwargs): + return self.run(None, + run_metadata=kwargs.get("run_metadata", None), + callable_options=callable_options, + callable_runner_args=feed_values) return wrapped_runner @property diff --git a/tensorflow/python/debug/wrappers/grpc_wrapper.py b/tensorflow/python/debug/wrappers/grpc_wrapper.py index 1f9c8fa5a96b4d6826fae0870608e0e737c7cd88..85944fa61118114cc73f9288f3f974f0a5a8a839 100644 --- a/tensorflow/python/debug/wrappers/grpc_wrapper.py +++ b/tensorflow/python/debug/wrappers/grpc_wrapper.py @@ -215,7 +215,8 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession): options=None, run_metadata=None, callable_runner=None, - callable_runner_args=None): + callable_runner_args=None, + callable_options=None): if self._send_traceback_and_source_code: self._sent_graph_version = publish_traceback( self._grpc_debug_server_urls, self.graph, feed_dict, fetches, @@ -226,4 +227,5 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession): options=options, run_metadata=run_metadata, callable_runner=callable_runner, - callable_runner_args=callable_runner_args) + callable_runner_args=callable_runner_args, + callable_options=callable_options) diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper.py b/tensorflow/python/debug/wrappers/local_cli_wrapper.py index c8625655e51a43a222addedd4beecdd3515d7fb6..668ffb57f10a69ce7e11e889fe613afbd618e823 100644 --- a/tensorflow/python/debug/wrappers/local_cli_wrapper.py +++ b/tensorflow/python/debug/wrappers/local_cli_wrapper.py @@ -290,6 +290,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): if self._run_call_count == 1: # Show logo at the onset of the first run. help_intro.extend(cli_shared.get_tfdbg_logo()) + help_intro.extend(debugger_cli_common.get_tensorflow_version_lines()) help_intro.extend(debugger_cli_common.RichTextLines("Upcoming run:")) help_intro.extend(self._run_info) @@ -466,6 +467,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): if self._run_call_count == 1: output.extend(cli_shared.get_tfdbg_logo()) + output.extend(debugger_cli_common.get_tensorflow_version_lines()) output.extend(self._run_info) if (not self._is_run_start and @@ -594,7 +596,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): # Register tab completion for the filter names. curses_cli.register_tab_comp_context(["run", "r"], list(self._tensor_filters.keys())) - if self._feed_dict: + if self._feed_dict and hasattr(self._feed_dict, "keys"): # Register tab completion for feed_dict keys. feed_keys = [common.get_graph_element_name(key) for key in self._feed_dict.keys()] diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py index b06fa26a935b42709575f8e400e0bda951ffbbc7..05c9eaa4d27319ecf5e12fdeb0a973246c61704a 100644 --- a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py +++ b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py @@ -21,7 +21,10 @@ import os import shutil import tempfile +import numpy as np + from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.debug.cli import cli_shared from tensorflow.python.debug.cli import debugger_cli_common @@ -149,7 +152,13 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase): dtypes.float32, shape=([5, 5]), name="sparse_placeholder") self.sparse_add = sparse_ops.sparse_add(self.sparse_ph, self.sparse_ph) - self.sess = session.Session() + rewriter_config = rewriter_config_pb2.RewriterConfig( + disable_model_pruning=True, + arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, + dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF) + graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) + config_proto = config_pb2.ConfigProto(graph_options=graph_options) + self.sess = session.Session(config=config_proto) # Initialize variable. self.sess.run(variables.global_variables_initializer()) @@ -393,6 +402,113 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase): self.assertAllClose(42.0, tensor_runner(41.0, 1.0)) self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"])) + def testDebuggingMakeCallableFromOptionsWithZeroFeedWorks(self): + variable_1 = variables.Variable( + 10.5, dtype=dtypes.float32, name="variable_1") + a = math_ops.add(variable_1, variable_1, "callable_a") + math_ops.add(a, a, "callable_b") + self.sess.run(variable_1.initializer) + + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["run"]] * 3, self.sess, dump_root=self._tmp_dir) + callable_options = config_pb2.CallableOptions() + callable_options.fetch.append("callable_b") + sess_callable = wrapped_sess._make_callable_from_options(callable_options) + + for _ in range(2): + callable_output = sess_callable() + self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0]) + + debug_dumps = wrapped_sess.observers["debug_dumps"] + self.assertEqual(2, len(debug_dumps)) + for debug_dump in debug_dumps: + node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data] + self.assertItemsEqual( + ["callable_a", "callable_b", "variable_1", "variable_1/read"], + node_names) + + def testDebuggingMakeCallableFromOptionsWithOneFeedWorks(self): + ph1 = array_ops.placeholder(dtypes.float32, name="callable_ph1") + a = math_ops.add(ph1, ph1, "callable_a") + math_ops.add(a, a, "callable_b") + + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["run"]] * 3, self.sess, dump_root=self._tmp_dir) + callable_options = config_pb2.CallableOptions() + callable_options.feed.append("callable_ph1") + callable_options.fetch.append("callable_b") + sess_callable = wrapped_sess._make_callable_from_options(callable_options) + + ph1_value = np.array([10.5, -10.5], dtype=np.float32) + + for _ in range(2): + callable_output = sess_callable(ph1_value) + self.assertAllClose( + np.array([42.0, -42.0], dtype=np.float32), callable_output[0]) + + debug_dumps = wrapped_sess.observers["debug_dumps"] + self.assertEqual(2, len(debug_dumps)) + for debug_dump in debug_dumps: + node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data] + self.assertItemsEqual(["callable_a", "callable_b"], node_names) + + def testDebuggingMakeCallableFromOptionsWithTwoFeedsWorks(self): + ph1 = array_ops.placeholder(dtypes.float32, name="callable_ph1") + ph2 = array_ops.placeholder(dtypes.float32, name="callable_ph2") + a = math_ops.add(ph1, ph2, "callable_a") + math_ops.add(a, a, "callable_b") + + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["run"]] * 3, self.sess, dump_root=self._tmp_dir) + callable_options = config_pb2.CallableOptions() + callable_options.feed.append("callable_ph1") + callable_options.feed.append("callable_ph2") + callable_options.fetch.append("callable_b") + sess_callable = wrapped_sess._make_callable_from_options(callable_options) + + ph1_value = np.array(5.0, dtype=np.float32) + ph2_value = np.array(16.0, dtype=np.float32) + + for _ in range(2): + callable_output = sess_callable(ph1_value, ph2_value) + self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0]) + + debug_dumps = wrapped_sess.observers["debug_dumps"] + self.assertEqual(2, len(debug_dumps)) + for debug_dump in debug_dumps: + node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data] + self.assertItemsEqual(["callable_a", "callable_b"], node_names) + + def testDebugMakeCallableFromOptionsWithCustomOptionsAndMetadataWorks(self): + variable_1 = variables.Variable( + 10.5, dtype=dtypes.float32, name="variable_1") + a = math_ops.add(variable_1, variable_1, "callable_a") + math_ops.add(a, a, "callable_b") + self.sess.run(variable_1.initializer) + + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["run"], ["run"]], self.sess, dump_root=self._tmp_dir) + callable_options = config_pb2.CallableOptions() + callable_options.fetch.append("callable_b") + callable_options.run_options.trace_level = config_pb2.RunOptions.FULL_TRACE + + sess_callable = wrapped_sess._make_callable_from_options(callable_options) + + run_metadata = config_pb2.RunMetadata() + # Call the callable with a custom run_metadata. + callable_output = sess_callable(run_metadata=run_metadata) + # Verify that step_stats is populated in the custom run_metadata. + self.assertTrue(run_metadata.step_stats) + self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0]) + + debug_dumps = wrapped_sess.observers["debug_dumps"] + self.assertEqual(1, len(debug_dumps)) + debug_dump = debug_dumps[0] + node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data] + self.assertItemsEqual( + ["callable_a", "callable_b", "variable_1", "variable_1/read"], + node_names) + def testRuntimeErrorShouldBeCaught(self): wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( [["run"], ["run"]], self.sess, dump_root=self._tmp_dir) diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index e8a7904a88f00b1466edcbfc627509518cc02b07..6ede8e4f4d9c549faae3223d400d25b7712bbc74 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -32,6 +32,7 @@ cc_library( "//tensorflow/python:numpy_lib", "//tensorflow/python:py_seq_tensor", "//tensorflow/python:safe_ptr", + "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", ], ) diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index bd97b181ff7fa5a38ea8ab16e55b3ade7b599261..3e3c82e56a8c957839e420550bfb073d400b4a77 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -605,7 +605,9 @@ def _zeros(shape, dtype): # TODO(apassos): need to save enough information about variant tensors to do # a zeros return None - cache_key = shape, dtype, device + # pylint: disable=protected-access + cache_key = shape, dtype, device, context.context()._eager_context.mode + # pylint: enable=protected-access cached = _zeros_cache.get(cache_key) if cached is None: cached = _fast_fill(0, shape, dtype) diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 826c6683b9668ab892883119a533ee8d497d7b58..ebbd3cd98e892fddb556fc95a4292e05d16fc167 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -46,7 +46,7 @@ from tensorflow.python.training import training class BackpropTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAggregateGradients(self): def fn(x): @@ -251,7 +251,7 @@ class BackpropTest(test.TestCase): g, = backprop.gradients_function(loss, [0])(logits, labels) self.assertAllEqual(g.numpy(), [[-0.5, 0.5]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientWithinTapeBlock(self): v1 = resource_variable_ops.ResourceVariable(1.) self.evaluate(v1.initializer) @@ -265,7 +265,7 @@ class BackpropTest(test.TestCase): grad = t.gradient(loss, v1) self.assertAllEqual(self.evaluate(grad), 2.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNestedSelfContexts(self): v1 = resource_variable_ops.ResourceVariable(1.) self.evaluate(v1.initializer) @@ -435,7 +435,7 @@ class BackpropTest(test.TestCase): self.assertEqual(backprop.implicit_grad(f)()[0][0], None) @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientTapeRepeatedSource(self): with backprop.GradientTape(persistent=False) as g: x = constant_op.constant(3.0) @@ -445,7 +445,7 @@ class BackpropTest(test.TestCase): self.assertEqual(self.evaluate(grad), [2.0, 2.0]) @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPersistentGradientTapeRepeatedSource(self): with backprop.GradientTape(persistent=True) as g: x = constant_op.constant(3.0) @@ -459,7 +459,7 @@ class BackpropTest(test.TestCase): self.assertEqual(self.evaluate(grad), [3.0, 11.0]) @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientTapeStructure(self): with backprop.GradientTape(persistent=True) as g: # Using different constant values because constant tensors are @@ -482,7 +482,7 @@ class BackpropTest(test.TestCase): [1.0, {'x2': 2.0, 'x3': 3.0}]) @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientTape(self): with backprop.GradientTape() as g: x = constant_op.constant(3.0) @@ -497,7 +497,7 @@ class BackpropTest(test.TestCase): grad = g.gradient(y, [x])[0] self.assertEqual(self.evaluate(grad), 6.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientTapeWithCond(self): x = constant_op.constant(3.0) @@ -518,7 +518,7 @@ class BackpropTest(test.TestCase): dy = g.gradient(y, [x])[0] self.assertEqual(self.evaluate(dy), 6.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientTapeWithWhileLoop(self): i = constant_op.constant(1) x = constant_op.constant(2.) @@ -553,7 +553,7 @@ class BackpropTest(test.TestCase): g.gradient(y, [x]) @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPersistentTape(self): with backprop.GradientTape(persistent=True) as g: x = constant_op.constant(3.0) @@ -567,7 +567,7 @@ class BackpropTest(test.TestCase): del g @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testHigherOrderGradient(self): with backprop.GradientTape(persistent=True) as g: x = constant_op.constant(3.0) @@ -584,7 +584,7 @@ class BackpropTest(test.TestCase): del g @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPersistentNestedTape(self): with backprop.GradientTape(persistent=True) as g: x = constant_op.constant(3.0) @@ -605,7 +605,7 @@ class BackpropTest(test.TestCase): del g @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientTapeVariable(self): v = resource_variable_ops.ResourceVariable(1.0, name='v') self.evaluate(v.initializer) @@ -615,7 +615,7 @@ class BackpropTest(test.TestCase): self.assertAllEqual(self.evaluate(grad), 2.0) @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNestedGradients(self): x = constant_op.constant(3.0) with backprop.GradientTape() as g: @@ -900,6 +900,33 @@ class BackpropTest(test.TestCase): 'did you forget to return a value from fn?'): val_and_grads_fn(x, y) + def testZerosCacheDoesntLeakAcrossModes(self): + with ops.Graph().as_default(): + t = random_ops.random_normal(shape=[100, 2]) + x = random_ops.random_normal(shape=[100, 4]) + dy = random_ops.random_normal(shape=[100, 4]) + with backprop.GradientTape() as gradient_tape: + gradient_tape.watch(x) + x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1) + y1 = x1 ** 2. + y = array_ops.concat([y1, t], axis=1) + + dx = gradient_tape.gradient(y, x, output_gradients=dy) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + sess.run(dx) + + t = random_ops.random_normal(shape=[100, 2]) + x = random_ops.random_normal(shape=[100, 4]) + dy = random_ops.random_normal(shape=[100, 4]) + with backprop.GradientTape() as gradient_tape: + gradient_tape.watch(x) + x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1) + y1 = x1 ** 2. + y = array_ops.concat([y1, t], axis=1) + + dx = gradient_tape.gradient(y, x, output_gradients=dy) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 9e146f021e813886b42ca72b07122b485901a24b..85b9491903de2ea6ffe1c5ac7ef76efdfda2818b 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -143,7 +143,11 @@ class Context(object): # TODO(agarwal): create and link in some documentation for `execution_mode`. # pylint: disable=redefined-outer-name - def __init__(self, config=None, device_policy=None, execution_mode=None): + def __init__(self, + config=None, + device_policy=None, + execution_mode=None, + server_def=None): """Creates a new Context. Args: @@ -192,6 +196,7 @@ class Context(object): if execution_mode is None: execution_mode = SYNC self._execution_mode = execution_mode + self._server_def = server_def # pylint: enable=redefined-outer-name @@ -231,6 +236,9 @@ class Context(object): opts, self._device_policy) if self._execution_mode == ASYNC: pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True) + if self._server_def is not None: + server_def_str = self._server_def.SerializeToString() + pywrap_tensorflow.TFE_ContextOptionsSetServerDef(opts, server_def_str) self._context_handle = pywrap_tensorflow.TFE_NewContext(opts) finally: pywrap_tensorflow.TFE_DeleteContextOptions(opts) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index dd3166735ce2959e94203123e80ff2a33520188b..7edcb0931dd6ed31b285c14dd578bcc3d693c724 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function import collections +import functools import numpy as np @@ -46,8 +47,11 @@ def capture_value(tensor_map, value, dtype, name): """Capture a value from outside the function, to pass in as an extra arg.""" captured_value = tensor_map.get(ops.tensor_id(value), None) if captured_value is None: - captured_value = graph_placeholder( - dtype=dtype or value.dtype, shape=value.shape, name=name) + # Note: setting ops.control_dependencies(None) ensures we always put + # capturing placeholders outside of any control flow context. + with ops.control_dependencies(None): + captured_value = graph_placeholder( + dtype=dtype or value.dtype, shape=value.shape, name=name) if captured_value.dtype == dtypes_module.resource: if ops._USE_C_SHAPES: # pylint: disable=protected-access if isinstance(value, ops.EagerTensor): @@ -313,7 +317,7 @@ class GraphModeFunction(object): graph, operations, outputs, - func_outputs, + python_func_outputs, output_shapes, variables=None, attrs=None): @@ -332,9 +336,10 @@ class GraphModeFunction(object): definition. outputs: a flat list of the Tensors in the graph used as outputs to the function - func_outputs: a possibly nested python object which will be returned by - this function. The Tensors in this structure will be replaced by their - corresponding values in outputs. + python_func_outputs: a possibly nested python object which will be + returned by this function. The Tensors in this structure will be + replaced by their corresponding values in outputs. Note that this + structure might contain Python `None`s. output_shapes: List of shapes of all tensors in outputs variables: (optional) List of variables to watch during function execution. @@ -356,9 +361,10 @@ class GraphModeFunction(object): self._function_def = defined_function self._num_outputs = len(defined_function.signature.output_arg) self._ops = operations - self._func_outputs = func_outputs - self._returns = [func_outputs] if isinstance( - func_outputs, (ops.Tensor, type(None))) else _flatten(func_outputs) + self._python_func_outputs = python_func_outputs + self._python_returns = [python_func_outputs] if isinstance( + python_func_outputs, + (ops.Tensor, type(None))) else _flatten(python_func_outputs) self._output_shapes = output_shapes self._variables = variables if variables is not None else [] @@ -373,7 +379,7 @@ class GraphModeFunction(object): c_captured_tensors = set() existing_op_len = len(self._graph.get_operations()) - filtered_outputs = [x for x in self._returns if x is not None] + filtered_outputs = [x for x in self._python_returns if x is not None] self._out_grad_placeholders = [ graph_placeholder(x.dtype, x.shape) for x in filtered_outputs] in_gradients = gradients_impl.gradients( @@ -449,10 +455,16 @@ class GraphModeFunction(object): if not outputs: return op outputs = [outputs] if isinstance(outputs, ops.Tensor) else list(outputs) - for i, s in enumerate(self._output_shapes): - outputs[i].set_shape(s) - real_outputs = outputs[:len(self._returns)] - side_outputs = outputs[len(self._returns):] + + shapes = [shape for shape in self._output_shapes if shape is not None] + for i, shape in enumerate(shapes): + outputs[i].set_shape(shape) + + # `real_outputs` are the actual outputs of the inference graph function; + # `side_outputs` are the intermediate Tensors that were added as outputs to + # the forward graph function so that we can compute its gradient. + real_outputs = outputs[:self._num_outputs] + side_outputs = outputs[self._num_outputs:] def backward_function(*args): return self._backward_function(*(list(args) + side_outputs)) # pylint: disable=not-callable @@ -469,8 +481,8 @@ class GraphModeFunction(object): def output_shapes(self): """The function's output shapes.""" # TODO(ebrevdo): Should we only keep the output shapes associated - # with len(self._returns) outputs? - outputs_list = nest.flatten(self._func_outputs) + # with len(self._python_returns) outputs? + outputs_list = nest.flatten(self._python_func_outputs) j = 0 for i, o in enumerate(outputs_list): if o is not None: @@ -484,12 +496,12 @@ class GraphModeFunction(object): else: outputs_list[i] = self._output_shapes[j] j += 1 - return nest.pack_sequence_as(self._func_outputs, outputs_list) + return nest.pack_sequence_as(self._python_func_outputs, outputs_list) @property def output_dtypes(self): return nest.map_structure( - lambda x: x.dtype if x is not None else None, self._func_outputs) + lambda x: x.dtype if x is not None else None, self._python_func_outputs) @property def captured_inputs(self): @@ -543,8 +555,10 @@ class GraphModeFunction(object): result = op.outputs if not result: return op - for i, s in enumerate(self._output_shapes): - result[i].set_shape(s) + + shapes = [shape for shape in self._output_shapes if shape is not None] + for i, shape in enumerate(shapes): + result[i].set_shape(shape) return self._build_call_outputs(result) @@ -556,11 +570,11 @@ class GraphModeFunction(object): Returns: The actual call output. """ - if self._func_outputs is None: + if self._python_func_outputs is None: return None # Use `nest.flatten` instead of `_flatten` in order to preserve any - # IndexedSlices in `self._func_outputs`. - outputs_list = nest.flatten(self._func_outputs) + # IndexedSlices in `self._python_func_outputs`. + outputs_list = nest.flatten(self._python_func_outputs) j = 0 for i, o in enumerate(outputs_list): if o is not None: @@ -580,7 +594,7 @@ class GraphModeFunction(object): else: outputs_list[i] = result[j] j += 1 - ret = nest.pack_sequence_as(self._func_outputs, outputs_list) + ret = nest.pack_sequence_as(self._python_func_outputs, outputs_list) return ret @@ -596,6 +610,10 @@ def _get_defun_inputs(args): return nest.pack_sequence_as(args, ret) +def _deterministic_dict_values(kwds): + return tuple(kwds[key] for key in sorted(kwds)) + + def _trace_and_define_function(name, func, compiled, args, kwds): """Defines and returns graph-mode version of func.""" graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access @@ -613,7 +631,8 @@ def _trace_and_define_function(name, func, compiled, args, kwds): tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection( collection) with tmp_graph.as_default(), AutomaticControlDependencies() as a: - func_inputs = _get_defun_inputs(args) + func_args = _get_defun_inputs(args) + func_kwds = _get_defun_inputs(kwds) def convert(x): if x is None: @@ -624,7 +643,7 @@ def _trace_and_define_function(name, func, compiled, args, kwds): this_tape = tape.push_new_tape() try: - func_outputs = func(*func_inputs, **kwds) + func_outputs = func(*func_args, **func_kwds) func_outputs = nest.map_structure(convert, func_outputs) finally: tape.pop_tape(this_tape) @@ -648,8 +667,11 @@ def _trace_and_define_function(name, func, compiled, args, kwds): x.shape if isinstance(x, ops.Tensor) else None for x in outputs_list) - flat_inputs = [x for x in nest.flatten(func_inputs) - if isinstance(x, ops.Tensor)] + func_kwds_values = _deterministic_dict_values(func_kwds) + flat_inputs = [ + x for x in nest.flatten(func_args) + nest.flatten(func_kwds_values) + if isinstance(x, ops.Tensor) + ] all_inputs = flat_inputs + list(extra_placeholders) all_ignored_ops = frozenset(x.op for x in all_inputs) fname = _inference_name(name) @@ -726,30 +748,58 @@ class _PolymorphicFunction(object): self._arguments_to_functions = {} self._variables = [] + def __get__(self, instance, owner): + """Makes it possible to defun instance methods.""" + del owner + # `instance` here is the instance that this `_PolymorphicFunction` was + # accessed through; e.g., for + # + # class Foo(object): + # + # @function.defun + # def bar(self): + # ... + # + # foo = Foo() + # foo.bar() # `foo.bar` is a `_PolymorphicFunction` instance + # + # then `instance` will be `foo` (and `owner` will be `Foo`). + return functools.partial(self.__call__, instance) + def _maybe_define_function(self, *args, **kwds): - """Gets a function for these inputs, defining it if necessary.""" + """Gets a function for these inputs, defining it if necessary. - # TODO(akshayka): Remove this restriction. - if any(isinstance(x, ops.EagerTensor) for x in kwds.values()): - raise ValueError("Tensor keyword arguments are not supported.") + Args: + *args: args for the Python function; used to compute the signature + **kwds: kwds for the Python function; used to compute the signature - # TODO(apassos): Better error messages for non-hashable arguments. - cache_key = tuple(_cache_key(x) for x in args) - cache_key = (cache_key, tuple(kwds.items())) + Returns: + A graph function corresponding to the input signature implied by args and + kwds, as well as the inputs that the object should be called with. + """ - if cache_key not in self._arguments_to_functions: + # TODO(apassos): Better error messages for non-hashable arguments. + kwd_values = _deterministic_dict_values(kwds) + inputs = args + kwd_values + signature = tuple(_cache_key(x) for x in inputs) + # The graph, or whether we're executing eagerly, should be a part of the + # signature so we don't improperly capture tensors such as variables. + signature += tuple([context.executing_eagerly() or ops.get_default_graph()]) + + if signature not in self._arguments_to_functions: graph_function = _trace_and_define_function( self._name, self._python_function, self._compiled, args, kwds) - self._arguments_to_functions[cache_key] = graph_function + self._arguments_to_functions[signature] = graph_function self._variables.extend( [v for v in graph_function.variables if v not in self._variables]) - return graph_function + return graph_function, inputs else: - return self._arguments_to_functions[cache_key] + return self._arguments_to_functions[signature], inputs def __call__(self, *args, **kwds): """Calls a graph function specialized for this input signature.""" - return self._maybe_define_function(*args, **kwds)(*args) + graph_function, inputs = self._maybe_define_function(*args, **kwds) + return graph_function(*inputs) @property def variables(self): @@ -765,22 +815,28 @@ def defun(func=None, compiled=False): `defun` (short for "define function") trace-compiles a Python function composed of TensorFlow operations into a callable that executes a @{tf.Graph} - containing those operations. When eager execution is enabled, the ability to - create graphs from Python functions makes it possible to incrementally trade - off debugability and interactivity for performance. Functions compiled with - `defun` cannot be inspected with `pdb` and `print` statements; however, - executing a graph generated by `defun` sometimes takes less time and memory - than eagerly executing the corresponding Python function, since specifying - computations as graphs allows for optimizations like automatic buffer reuse - and parallelization among ops. Note that executing a `defun`-compiled function + containing those operations. The callable produced by `defun` contains only + the subgraph of TensorFlow operations that were executed when the Python + function was called with a particular input signature, defined as a list + of the shapes and dtypes of the Python function's Tensor-valued arguments and + the values of its non-Tensor Python objects. In particular, `defun` is _not_ a + compiler for arbitrary Python code. + + When eager execution is enabled, the ability to create graphs from Python + functions makes it possible to incrementally trade off debugability and + interactivity for performance. Functions compiled with `defun` cannot be + inspected with `pdb` and `print` statements; however, executing a graph + generated by `defun` sometimes takes less time and memory than eagerly + executing the corresponding Python function, since specifying computations as + graphs allows for optimizations like automatic buffer reuse and + parallelization among ops. Note that executing a `defun`-compiled function incurs a small constant overhead, so eagerly executing sufficiently small Python functions might take less time than executing their corresponding `defun`-generated graphs. - For a Python function to be compatible with `defun`, the values of its keyword - arguments cannot be Tensors and all of its arguments, including its keyword - arguments, must be hashable Python objects or lists thereof. Additionally, it - must return zero or more @{tf.Tensor} objects. + For a Python function to be compatible with `defun`, all of its arguments must + be hashable Python objects or lists thereof. Additionally, it must return zero + or more @{tf.Tensor} objects. _Example Usage_ @@ -853,20 +909,23 @@ def defun(func=None, compiled=False): _Tracing and Input Signatures_. The signature of inputs supplied to `F` is defined to be a tuple of the shapes - and dtypes of Tensor-typed arguments and the values of non-Tensor arguments - and keyword arguments. Every time `F` is invoked, the signature of its inputs - are inferred. The first time `F(*args, **kwargs)` is invoked with a particular - signature, `f(*args, **kwargs)` is executed and all the TensorFlow operations - that `f` executes, along with the Tensors that flow between them, are recorded - in a TensorFlow graph. `F` caches this graph and binds it to the inputs' - signature; every subsequent invocation of `F` with inputs conforming to this - signature will immediately retrieve the cached graph and pass it to the - TensorFlow runtime for execution. - - Be aware that because `F` only logs TensorFlow operations, all non-TensorFlow - operations that `f` executes will only shape the _construction_ of the graphs - that `F` executes: They won't be executed when the graphs themselves are - executed. For example, whereas the Python function + and dtypes of Tensor-typed arguments and the values of non-Tensor arguments, + where "arguments" includes both args and kwargs. Every time `F` is invoked, + the signature of its inputs are inferred. The first time `F(*args, **kwargs)` + is invoked with a particular signature, `f(*args, **kwargs)` is executed and + all the TensorFlow operations that `f` executes, along with the Tensors that + flow between them, are recorded in a TensorFlow graph. `F` caches this graph + and binds it to the inputs' signature; every subsequent invocation of `F` with + inputs conforming to this signature will immediately retrieve the cached graph + and pass it to the TensorFlow runtime for execution. + + Be aware that because `F` only logs TensorFlow operations, all the other + Python code that `f` executes will only shape the _construction_ of the graphs + that `F` executes: the Python code won't be executed when the graphs + themselves are executed, though it will be executed every time the Python + function is traced (and a given Python function might be traced multiple + times, once for each input signature it is invoked with). For example, whereas + the Python function ```python import tensorflow as tf @@ -874,17 +933,23 @@ def defun(func=None, compiled=False): tf.enable_eager_execution() - matrix = tf.eye(5) - # `matrix` is assumed to be a Tensor def add_noise(): - return matrix + np.random.randn(matrix.shape[0], matrix.shape[1]) + return tf.eye(5) + np.random.randn(5, 5) ``` will return a different output everytime it is invoked, the compiled function `compiled = tf.contrib.eager.defun(add_noise)` will return the same value every time it is called, since a particular random offset generated by NumPy will be inserted into the graph as a TensorFlow constant. The solution is to - replace the call to `np.random.randn` with `tf.random_normal(matrix.shape)`. + replace the call to `np.random.randn` with `tf.random_normal((5, 5))`. + + _Python Side-Effects_ + A corollary of the previous discussion on tracing is the following: If a + Python function `f` has Python side-effects, then executing `f` multiple times + will not necessarily be semantically equivalent to executing `F = + tf.contrib.eager.defun(f)` multiple times; this difference is due to the fact + that `defun` only captures the subgraph of TensorFlow operations that is + constructed when `f` is called in a graph-building context. _Python Control Flow_. The structure of many machine learning computations depend upon whether one is @@ -1068,15 +1133,8 @@ def make_defun_op(func, *args, **kwds): A wrapper object which can be queried for its output properties, and which can be called directly the way a `@defun` wrapped function can. - - Raises: - ValueError: if any of the keyword arguments to `func` are `EagerTensor` - objects (not yet supported). """ - name = func.__name__ - if any(isinstance(x, ops.EagerTensor) for x in kwds.values()): - raise ValueError("Tensor keyword arguments are not supported.") - return _trace_and_define_function(name, func, False, args, kwds) + return _trace_and_define_function(func.__name__, func, False, args, kwds) class AutomaticControlDependencies(object): diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 6ce2ceffda7bb08717c4cccab03eadacba5c6655..cf32f6e7fb238992df793ef1707e44c40ccde980 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -34,6 +34,7 @@ from tensorflow.python.layers import convolutional from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -90,6 +91,32 @@ class FunctionTest(test.TestCase): self.assertAllEqual(step(), 2.0) + def testGraphGradientVariable(self): + with ops.Graph().as_default(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + + @function.defun + def f(): + return 2.0 * v + + node = f() + grads, = gradients_impl.gradients(node, v) + v.initializer.run() + self.assertAllEqual(grads.eval(), 2.0) + self.assertEqual(grads.shape, v.shape) + + def testGraphEagerIsolation(self): + + @function.defun + def f(): + v = resource_variable_ops.ResourceVariable(1.0) + return v.read_value() + + self.assertAllEqual(f(), 1.0) + + with ops.Graph().as_default(): + self.assertEqual(f().shape, ()) + def testBasicDefunOpGraphMode(self): matmul = function.defun(math_ops.matmul) @@ -196,6 +223,21 @@ class FunctionTest(test.TestCase): compiled = function.defun(f) compiled() + def testVariableInLoopInFunction(self): + + @function.defun + def test_function(): + + def loop_test(_): + return False + + def loop_body(_): + return variable_scope.get_variable('a', shape=()) + + return control_flow_ops.while_loop(loop_test, loop_body, [0.0]) + + self.assertEqual(test_function().shape, []) + def testDefunShapeInferenceWithCapturedResourceVariableInGraphMode(self): with context.graph_mode(): v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]]) @@ -512,6 +554,60 @@ class FunctionTest(test.TestCase): g = backprop.gradients_function(wrapper, [0])(constant_op.constant(0.0)) self.assertAllEqual(g[0], 1.) + @function.defun + def foo(a): + return None, a * a + + x = constant_op.constant(5.0) + with backprop.GradientTape() as tp: + tp.watch(x) + none, r = foo(x) + g = tp.gradient(r, x) + + self.assertIs(none, None) + self.assertAllEqual(r, 25.0) + self.assertAllEqual(g, 2 * 5.0) + + def testNestedDifferentiableFunction(self): + @function.defun + def foo(a, b): + return a * math_ops.add(a, b) + + @function.defun + def bar(x): + return foo(x, 1.0) + + x = constant_op.constant(5.0) + with backprop.GradientTape() as tp: + tp.watch(x) + result = bar(x) + grad = tp.gradient(result, x) + + self.assertAllEqual(grad, 2 * 5.0 + 1.0) + + def testNestedDifferentiableFunctionNoneOutputs(self): + @function.defun + def foo(a, b): + return None, a * math_ops.add(a, b), None, 2*a + + @function.defun + def bar(x): + return foo(x, 1.0) + + x = constant_op.constant(5.0) + with backprop.GradientTape(persistent=True) as tp: + tp.watch(x) + none1, r1, none2, r2 = bar(x) + g1 = tp.gradient(r1, x) + g2 = tp.gradient(r2, x) + + self.assertAllEqual(r1, 30.0) + self.assertAllEqual(r2, 10.0) + self.assertIs(none1, None) + self.assertIs(none2, None) + self.assertAllEqual(g1, 2 * 5.0 + 1.0) + self.assertAllEqual(g2, 2.0) + def testNoneOutput(self): @function.defun @@ -650,6 +746,89 @@ class FunctionTest(test.TestCase): _ = defined(x) # ensure the variables list remains the same self.assertAllEqual(defined.variables, [v]) + def testTensorKeywordArguments(self): + + def foo(a, b): + del a + return b + + defined = function.defun(foo) + a = constant_op.constant(2.0) + b = constant_op.constant([1.0, 2.0]) + one = defined(a, b) + self.assertEqual(len(defined._arguments_to_functions), 1) + + two = defined(a=a, b=b) + self.assertEqual(len(defined._arguments_to_functions), 1) + + three = defined(b=b, a=a) + self.assertEqual(len(defined._arguments_to_functions), 1) + + four = defined(a, b=b) + self.assertEqual(len(defined._arguments_to_functions), 1) + + # The next call corresponds to a new input signature, hence + # we expect another function to be defined. + five = defined(b, a) + self.assertEqual(len(defined._arguments_to_functions), 2) + + six = defined(a=b, b=a) + self.assertEqual(len(defined._arguments_to_functions), 2) + + seven = defined(b=a, a=b) + self.assertEqual(len(defined._arguments_to_functions), 2) + + self.assertAllEqual(one, [1.0, 2.0]) + self.assertAllEqual(two, [1.0, 2.0]) + self.assertAllEqual(three, [1.0, 2.0]) + self.assertAllEqual(four, [1.0, 2.0]) + self.assertAllEqual(five, 2.0) + self.assertAllEqual(six, 2.0) + self.assertAllEqual(seven, 2.0) + + def testGradientWithKeywordArguments(self): + matmul = function.defun(math_ops.matmul) + + def sq(x): + return matmul(a=x, b=x, transpose_a=True) + + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + grad_t, = backprop.gradients_function(sq, [0])(t) + self.assertAllEqual(grad_t, [[6, 6], [14, 14]]) + + with backprop.GradientTape(persistent=True) as gtape: + gtape.watch(t) + one = matmul(t, b=t, transpose_a=True) + two = matmul(b=t, a=t, transpose_a=True) + three = matmul(a=t, b=t, transpose_a=True) + + for output in [one, two, three]: + self.assertAllEqual(gtape.gradient(output, t), [[6, 6], [14, 14]]) + + def testGradientInFunctionWithKeywordArguments(self): + + @function.defun + def f(x): + return backprop.gradients_function(lambda y: y * y, [0])(x)[0] + + self.assertAllEqual(f(x=constant_op.constant(1.0)), 2.0) + + def testDecoratingInstanceMethod(self): + + class Foo(object): + + def one(self, tensor): + return tensor + + @function.defun + def two(self, tensor): + return self.one(tensor) + + foo = Foo() + t = constant_op.constant(1.0) + out = foo.two(t) + self.assertEqual(float(out), 1.0) + @test_util.with_c_shapes class AutomaticControlDependenciesTest(test.TestCase): diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 6c9481c3af31a96308bb47f6e2aa35988f16f709..57b4dab51cc766042dfa895b197b3e3de037269d 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -205,14 +205,20 @@ bool ParseDimensionValue(const string& key, PyObject* py_value, } bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status, - const char** value) { + tensorflow::StringPiece* value) { if (PyBytes_Check(py_value)) { - *value = PyBytes_AsString(py_value); + Py_ssize_t size = 0; + char* buf = nullptr; + if (PyBytes_AsStringAndSize(py_value, &buf, &size) < 0) return false; + *value = tensorflow::StringPiece(buf, size); return true; } #if PY_MAJOR_VERSION >= 3 if (PyUnicode_Check(py_value)) { - *value = PyUnicode_AsUTF8(py_value); + Py_ssize_t size = 0; + char* buf = PyUnicode_AsUTF8AndSize(py_value, &size); + if (buf == nullptr) return false; + *value = tensorflow::StringPiece(buf, size); return true; } #endif @@ -275,8 +281,16 @@ bool SetOpAttrList( } if (type == TF_ATTR_STRING) { - PARSE_LIST(const char*, ParseStringValue); - TFE_OpSetAttrStringList(op, key, values.get(), num_values); + std::unique_ptr values(new const void*[num_values]); + std::unique_ptr lengths(new size_t[num_values]); + for (int i = 0; i < num_values; ++i) { + tensorflow::StringPiece value; + tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)); + if (!ParseStringValue(key, py_value.get(), status, &value)) return false; + values[i] = value.data(); + lengths[i] = value.size(); + } + TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values); } else if (type == TF_ATTR_INT) { PARSE_LIST(int64_t, ParseInt64Value); TFE_OpSetAttrIntList(op, key, values.get(), num_values); @@ -379,12 +393,15 @@ void SetOpAttrListDefault( TF_Status* status) { if (type == TF_ATTR_STRING) { int num_values = attr.default_value().list().s_size(); - std::unique_ptr values(new const char*[num_values]); + std::unique_ptr values(new const void*[num_values]); + std::unique_ptr lengths(new size_t[num_values]); (*attr_list_sizes)[key] = num_values; for (int i = 0; i < num_values; i++) { - values[i] = attr.default_value().list().s(i).data(); + const string& v = attr.default_value().list().s(i); + values[i] = v.data(); + lengths[i] = v.size(); } - TFE_OpSetAttrStringList(op, key, values.get(), num_values); + TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values); } else if (type == TF_ATTR_INT) { int num_values = attr.default_value().list().i_size(); std::unique_ptr values(new int64_t[num_values]); @@ -470,9 +487,9 @@ bool SetOpAttrScalar( tensorflow::gtl::FlatMap* attr_list_sizes, TF_Status* status) { if (type == TF_ATTR_STRING) { - const char* value; + tensorflow::StringPiece value; if (!ParseStringValue(key, py_value, status, &value)) return false; - TFE_OpSetAttrString(op, key, value); + TFE_OpSetAttrString(op, key, value.data(), value.size()); } else if (type == TF_ATTR_INT) { int64_t value; if (!ParseInt64Value(key, py_value, status, &value)) return false; @@ -533,7 +550,7 @@ bool SetOpAttrScalar( // (which is what the various "defun" or "Defun" decorators do). // And in the future also allow an object that can encapsulate // the function name and its attribute values. - const char* func_name = nullptr; + tensorflow::StringPiece func_name; if (!ParseStringValue(key, py_value, status, &func_name)) { PyObject* name_attr = PyObject_GetAttrString(py_value, "name"); if (name_attr == nullptr || @@ -549,7 +566,8 @@ bool SetOpAttrScalar( return false; } } - TFE_Op* func = TFE_NewOp(ctx, func_name, status); + TFE_Op* func = TFE_NewOp( + ctx, string(func_name.data(), func_name.size()).c_str(), status); if (TF_GetCode(status) != TF_OK) return false; TFE_OpSetAttrFunction(op, key, func); TFE_DeleteOp(func); @@ -930,7 +948,7 @@ class GradientTape : id(id), variable(variable) {} }; struct CompareById { - bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) { + bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) const { return lhs.id < rhs.id; } }; diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index 9e716e81f40a2395b6ff04989f695dc5c0d91d15..8ee38d35cc152e6c281e83d7fd49540ddaee2a7e 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -1,8 +1,4 @@ -package( - default_visibility = [ - "//tensorflow:internal", - ], -) +package(default_visibility = ["//tensorflow:internal"]) licenses(["notice"]) # Apache 2.0 @@ -10,7 +6,10 @@ load("//tensorflow:tensorflow.bzl", "py_test") py_library( name = "estimator_py", - srcs = ["estimator_lib.py"], + srcs = [ + "__init__.py", + "estimator_lib.py", + ], srcs_version = "PY2AND3", visibility = [ "//tensorflow:__pkg__", @@ -31,7 +30,7 @@ py_library( ":parsing_utils", ":run_config", ":training", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -41,10 +40,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":gc", - "//tensorflow/python:errors", - "//tensorflow/python:platform", - "//tensorflow/python:summary", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:util", ], @@ -58,10 +54,7 @@ py_test( deps = [ ":estimator", ":exporter", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:platform", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -70,8 +63,7 @@ py_library( srcs = ["gc.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:platform", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -82,10 +74,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":gc", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -95,12 +84,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":export_output", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python/saved_model:signature_constants", - "//tensorflow/python/saved_model:tag_constants", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -113,12 +97,7 @@ py_test( deps = [ ":export_output", ":model_fn", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", - "//tensorflow/python/saved_model:signature_constants", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -130,11 +109,7 @@ py_library( ":estimator", ":exporter", ":run_config", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:framework_ops", - "//tensorflow/python:platform", - "//tensorflow/python:training", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -153,13 +128,7 @@ py_test( ":inputs", ":run_config", ":training", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:platform", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -168,7 +137,7 @@ py_library( srcs = ["run_config.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/core:protos_all_py", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -180,8 +149,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":run_config", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -194,14 +162,7 @@ py_library( ":head", ":model_fn", ":optimizers", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers", - "//tensorflow/python:nn", - "//tensorflow/python:partitioned_variables", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -225,26 +186,7 @@ py_test( ":numpy_io", ":pandas_io", ":run_config", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:client", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -257,20 +199,7 @@ py_library( ":estimator", ":head", ":model_fn", - "//tensorflow/python:array_ops", - "//tensorflow/python:boosted_trees_ops", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:distribute", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -279,21 +208,13 @@ py_test( size = "medium", srcs = ["canned/boosted_trees_test.py"], srcs_version = "PY2AND3", + tags = [ + "optonly", + ], deps = [ ":boosted_trees", - "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:platform_test", - "//tensorflow/python:resources", - "//tensorflow/python:training", - "//tensorflow/python/estimator:numpy_io", - "//tensorflow/python/feature_column", + ":inputs", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -306,14 +227,7 @@ py_library( ":head", ":model_fn", ":optimizers", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers", - "//tensorflow/python:nn", - "//tensorflow/python:partitioned_variables", - "//tensorflow/python:summary", - "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -330,22 +244,7 @@ py_library( ":model_fn", ":numpy_io", ":prediction_keys", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:client", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:distribute", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variables", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", "//third_party/py/numpy", "@six_archive//:six", ], @@ -368,16 +267,7 @@ py_test( ":numpy_io", ":pandas_io", ":prediction_keys", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -393,19 +283,7 @@ py_library( ":linear", ":model_fn", ":optimizers", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:distribute", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers", - "//tensorflow/python:nn", - "//tensorflow/python:partitioned_variables", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -428,17 +306,7 @@ py_test( ":numpy_io", ":pandas_io", ":prediction_keys", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:nn", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variables", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -450,10 +318,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:platform", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python/data", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -464,10 +329,7 @@ py_test( tags = ["notsan"], # b/67510291 deps = [ ":util", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:training", - "//tensorflow/python/data", + "//tensorflow:tensorflow_py_no_contrib", "//third_party/py/numpy", "@six_archive//:six", ], @@ -484,21 +346,7 @@ py_library( ":model_fn", ":run_config", ":util", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:distribute", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:metrics", - "//tensorflow/python:platform", - "//tensorflow/python:random_seed", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python/data", - "//tensorflow/python/saved_model:builder", - "//tensorflow/python/saved_model:constants", - "//tensorflow/python/saved_model:tag_constants", + "//tensorflow:tensorflow_py_no_contrib", "//third_party/py/numpy", "@six_archive//:six", ], @@ -517,29 +365,7 @@ py_test( ":model_fn", ":numpy_io", ":run_config", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers", - "//tensorflow/python:lib", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", - "//tensorflow/python:saver_test_utils", - "//tensorflow/python:session", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variables", - "//tensorflow/python/data", - "//tensorflow/python/ops/losses", - "//tensorflow/python/saved_model:loader", - "//tensorflow/python/saved_model:tag_constants", + "//tensorflow:tensorflow_py_no_contrib", "//third_party/py/numpy", "@six_archive//:six", ], @@ -552,9 +378,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:dtypes", - "//tensorflow/python:parsing_ops", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -565,10 +389,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":parsing_utils", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:parsing_ops", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -577,9 +398,7 @@ py_library( srcs = ["export/export_output.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python/saved_model:signature_def_utils", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -591,13 +410,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":export_output", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python/saved_model:signature_constants", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -610,7 +423,7 @@ py_library( deps = [ ":export_export", ":export_output", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -622,13 +435,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":util", - "//tensorflow/python:array_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -641,17 +448,8 @@ py_test( deps = [ ":export_export", ":export_output", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python/saved_model:signature_constants", - "//tensorflow/python/saved_model:signature_def_utils", + ":util", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -664,24 +462,7 @@ py_library( ":metric_keys", ":model_fn", ":prediction_keys", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:nn", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:string_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:weights_broadcast_ops", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", - "//tensorflow/python/saved_model:signature_constants", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -700,23 +481,7 @@ py_test( ":model_fn", ":numpy_io", ":prediction_keys", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:string_ops", - "//tensorflow/python:training", - "//tensorflow/python:variables", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", - "//tensorflow/python/saved_model:signature_constants", + "//tensorflow:tensorflow_py_no_contrib", "//third_party/py/numpy", "@six_archive//:six", ], @@ -729,7 +494,7 @@ py_library( deps = [ ":numpy_io", ":pandas_io", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -741,11 +506,7 @@ py_library( ":estimator", ":head", ":optimizers", - "//tensorflow/python:partitioned_variables", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -763,25 +524,7 @@ py_library( ":numpy_io", ":pandas_io", ":run_config", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:check_ops", - "//tensorflow/python:client", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:distribute", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -799,7 +542,7 @@ py_test( deps = [ ":linear", ":linear_testing_utils", - "//tensorflow/python:client_testlib", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -828,9 +571,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":numpy_io", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -839,7 +580,7 @@ py_library( srcs = ["canned/optimizers.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -851,8 +592,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":optimizers", - "//tensorflow/python:client_testlib", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -870,9 +610,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":pandas_io", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -892,15 +630,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:summary", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -914,7 +644,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":inputs_queues", - "//tensorflow/python:client_testlib", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -925,10 +655,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":inputs_queues", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:session", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -941,32 +668,7 @@ py_library( ":export_export", ":model_fn", ":run_config", - "//tensorflow/python:check_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers", - "//tensorflow/python:math_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:nn", - "//tensorflow/python:partitioned_variables", - "//tensorflow/python:platform", - "//tensorflow/python:random_seed", - "//tensorflow/python:session", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:summary", - "//tensorflow/python:tensor_util", - "//tensorflow/python:training", - "//tensorflow/python:training_util", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/feature_column", - "//tensorflow/python/keras:backend", - "//tensorflow/python/keras:engine", - "//tensorflow/python/keras:layers", - "//tensorflow/python/ops/losses", - "//tensorflow/python/saved_model", - "//tensorflow/python/saved_model:signature_constants", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -981,18 +683,41 @@ py_test( ], deps = [ ":keras", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform", - "//tensorflow/python:summary", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:run_config", - "//tensorflow/python/keras", - "//tensorflow/python/keras:backend", - "//tensorflow/python/keras:engine", "//third_party/py/numpy", ], ) + +py_library( + name = "expect_numpy_installed", + # This is a dummy rule used as a numpy dependency in open-source. + # We expect numpy to already be installed on the system, e.g. via + # `pip install numpy` + visibility = ["//visibility:public"], +) + +py_library( + name = "expect_pandas_installed", + # This is a dummy rule used as a numpy dependency in open-source. + # We expect pandas to already be installed on the system, e.g. via + # `pip install pandas` + visibility = ["//visibility:public"], +) + +py_library( + name = "expect_six_installed", + # This is a dummy rule used as a numpy dependency in open-source. + # We expect six to already be installed on the system, e.g. via + # `pip install six` + visibility = ["//visibility:public"], +) + +py_library( + name = "expect_tensorflow_installed", + # This is a dummy rule used as a numpy dependency in open-source. + # We expect tensorflow to already be installed on the system, e.g. via + # `pip install tensorflow` or `pip install tensorflow_gpu` + visibility = ["//visibility:public"], +) diff --git a/tensorflow/python/estimator/__init__.py b/tensorflow/python/estimator/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..8cf8df567f0e36604b5c3f6fe992b572d6632954 100644 --- a/tensorflow/python/estimator/__init__.py +++ b/tensorflow/python/estimator/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Import Estimator APIs. + +Note: This file is imported by the create_estimator_api genrule. It must +transitively import all Estimator modules/packages for their @estimator_export +annotations to generate the public Estimator python API. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.python.estimator.estimator_lib diff --git a/tensorflow/python/estimator/api/BUILD b/tensorflow/python/estimator/api/BUILD index cddee9b8f30555da63a9aad1190a7644d02e5392..aa5a29e6dd148c39ebb098cb99cb1907d9c5a9d9 100644 --- a/tensorflow/python/estimator/api/BUILD +++ b/tensorflow/python/estimator/api/BUILD @@ -14,4 +14,5 @@ gen_api_init_files( api_name = "estimator", output_files = ESTIMATOR_API_INIT_FILES, package = "tensorflow.python.estimator", + package_dep = "//tensorflow/python/estimator:estimator_py", ) diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py index 86dbf272efeaa710c7327b8fa4122827cb248af8..8afef1b65a8d57e2b7ce3e4e512c622ca107ab83 100644 --- a/tensorflow/python/estimator/canned/boosted_trees.py +++ b/tensorflow/python/estimator/canned/boosted_trees.py @@ -168,9 +168,10 @@ def _group_features_by_num_buckets(sorted_feature_columns): # pylint:enable=protected-access # Replace the dummy key with the real max num of buckets for all bucketized # columns. - bucket_size_to_feature_ids_dict[ - max_buckets_for_bucketized] = bucket_size_to_feature_ids_dict[ - _DUMMY_NUM_BUCKETS] + if max_buckets_for_bucketized not in bucket_size_to_feature_ids_dict: + bucket_size_to_feature_ids_dict[max_buckets_for_bucketized] = [] + bucket_size_to_feature_ids_dict[max_buckets_for_bucketized].extend( + bucket_size_to_feature_ids_dict[_DUMMY_NUM_BUCKETS]) del bucket_size_to_feature_ids_dict[_DUMMY_NUM_BUCKETS] feature_ids_list = list(bucket_size_to_feature_ids_dict.values()) diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py index 9ea4f484744762a98c67207d582bcc5b7be8d850..33e9e69b041a7d250c9d86bdf8912bf0585f7d81 100644 --- a/tensorflow/python/estimator/canned/boosted_trees_test.py +++ b/tensorflow/python/estimator/canned/boosted_trees_test.py @@ -500,6 +500,50 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): self.assertEqual(2, ensemble.trees[0].nodes[0].bucketized_split.feature_id) self.assertEqual(0, ensemble.trees[0].nodes[0].bucketized_split.threshold) + def testTrainEvaluateAndPredictWithOnlyIndicatorColumn(self): + categorical = feature_column.categorical_column_with_vocabulary_list( + key='categorical', vocabulary_list=('bad', 'good', 'ok')) + feature_indicator = feature_column.indicator_column(categorical) + + labels = np.array([[0.], [5.7], [5.7], [0.], [0.]], dtype=np.float32) + # Our categorical feature defines the labels perfectly + input_fn = numpy_io.numpy_input_fn( + x={ + 'categorical': np.array(['bad', 'good', 'good', 'ok', 'bad']), + }, + y=labels, + batch_size=5, + shuffle=False) + + # Train depth 1 tree. + est = boosted_trees.BoostedTreesRegressor( + feature_columns=[feature_indicator], + n_batches_per_layer=1, + n_trees=1, + learning_rate=1.0, + max_depth=1) + + num_steps = 1 + est.train(input_fn, steps=num_steps) + ensemble = self._assert_checkpoint_and_return_model( + est.model_dir, global_step=1, finalized_trees=1, attempted_layers=1) + + # We learnt perfectly. + eval_res = est.evaluate(input_fn=input_fn, steps=1) + self.assertAllClose(eval_res['loss'], 0) + + predictions = list(est.predict(input_fn)) + self.assertAllClose( + labels, + [pred['predictions'] for pred in predictions]) + + self.assertEqual(3, len(ensemble.trees[0].nodes)) + + # Check that the split happened on 'good' value, which will be encoded as + # feature with index 1 (0 - 'bad', 2 - 'ok') + self.assertEqual(1, ensemble.trees[0].nodes[0].bucketized_split.feature_id) + self.assertEqual(0, ensemble.trees[0].nodes[0].bucketized_split.threshold) + class ModelFnTests(test_util.TensorFlowTestCase): """Tests bt_model_fn including unexposed internal functionalities.""" diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py index 90889e3e5d9f022f53c1f9f754bb01ae0a292f9c..2c7c4285caadf70777d43a9c30b1d8e95b8158ab 100644 --- a/tensorflow/python/estimator/canned/dnn.py +++ b/tensorflow/python/estimator/canned/dnn.py @@ -230,6 +230,17 @@ class DNNClassifier(estimator.Estimator): l1_regularization_strength=0.001 )) + # Or estimator using an optimizer with a learning rate decay. + estimator = DNNClassifier( + feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], + hidden_units=[1024, 512, 256], + optimizer=lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + # Or estimator with warm-starting from a previous checkpoint. estimator = DNNClassifier( feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], @@ -317,8 +328,9 @@ class DNNClassifier(estimator.Estimator): encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also there will be errors if vocabulary is not provided and labels are string. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to Adagrad optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to Adagrad optimizer. activation_fn: Activation function applied to each layer. If `None`, will use `tf.nn.relu`. dropout: When not `None`, the probability we will drop out a given @@ -385,6 +397,17 @@ class DNNRegressor(estimator.Estimator): l1_regularization_strength=0.001 )) + # Or estimator using an optimizer with a learning rate decay. + estimator = DNNRegressor( + feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], + hidden_units=[1024, 512, 256], + optimizer=lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + # Or estimator with warm-starting from a previous checkpoint. estimator = DNNRegressor( feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], @@ -465,8 +488,9 @@ class DNNRegressor(estimator.Estimator): used as a key to fetch weight tensor from the `features`. If it is a `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then weight_column.normalizer_fn is applied on it to get weight tensor. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to Adagrad optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to Adagrad optimizer. activation_fn: Activation function applied to each layer. If `None`, will use `tf.nn.relu`. dropout: When not `None`, the probability we will drop out a given diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py index 3d1ad1365bc66b3ef8b973257dd8b86ded0ea847..2f20e4b289f4fd55bc872bacc2de36bceade49dc 100644 --- a/tensorflow/python/estimator/canned/dnn_linear_combined.py +++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py @@ -257,12 +257,19 @@ class DNNLinearCombinedClassifier(estimator.Estimator): # warm-start settings warm_start_from="/path/to/checkpoint/dir") - # To apply L1 and L2 regularization, you can set optimizers as follows: + # To apply L1 and L2 regularization, you can set dnn_optimizer to: tf.train.ProximalAdagradOptimizer( learning_rate=0.1, l1_regularization_strength=0.001, l2_regularization_strength=0.001) - # It is same for FtrlOptimizer. + # To apply learning rate decay, you can set dnn_optimizer to a callable: + lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96) + # It is the same for linear_optimizer. # Input builders def input_fn_train: # returns x, y @@ -325,12 +332,16 @@ class DNNLinearCombinedClassifier(estimator.Estimator): used by linear part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the linear part of the model. Defaults to FTRL optimizer. + the linear part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to FTRL + optimizer. dnn_feature_columns: An iterable containing all the feature columns used by deep part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the deep part of the model. Defaults to Adagrad optimizer. + the deep part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to Adagrad + optimizer. dnn_hidden_units: List of hidden units per layer. All layers are fully connected. dnn_activation_fn: Activation function applied to each layer. If None, @@ -441,12 +452,19 @@ class DNNLinearCombinedRegressor(estimator.Estimator): # warm-start settings warm_start_from="/path/to/checkpoint/dir") - # To apply L1 and L2 regularization, you can set optimizers as follows: + # To apply L1 and L2 regularization, you can set dnn_optimizer to: tf.train.ProximalAdagradOptimizer( learning_rate=0.1, l1_regularization_strength=0.001, l2_regularization_strength=0.001) - # It is same for FtrlOptimizer. + # To apply learning rate decay, you can set dnn_optimizer to a callable: + lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96) + # It is the same for linear_optimizer. # Input builders def input_fn_train: # returns x, y @@ -508,12 +526,16 @@ class DNNLinearCombinedRegressor(estimator.Estimator): used by linear part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the linear part of the model. Defaults to FTRL optimizer. + the linear part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to FTRL + optimizer. dnn_feature_columns: An iterable containing all the feature columns used by deep part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the deep part of the model. Defaults to Adagrad optimizer. + the deep part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to Adagrad + optimizer. dnn_hidden_units: List of hidden units per layer. All layers are fully connected. dnn_activation_fn: Activation function applied to each layer. If None, diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py index ac59e786c414f2093f8ab2c6eeb26101acdb2600..e22df849e52000e125c6bf2015485e3496f8bb8d 100644 --- a/tensorflow/python/estimator/canned/linear.py +++ b/tensorflow/python/estimator/canned/linear.py @@ -193,6 +193,17 @@ class LinearClassifier(estimator.Estimator): l1_regularization_strength=0.001 )) + # Or estimator using an optimizer with a learning rate decay. + estimator = LinearClassifier( + feature_columns=[categorical_column_a, + categorical_feature_a_x_categorical_feature_b], + optimizer=lambda: tf.train.FtrlOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + # Or estimator with warm-starting from a previous checkpoint. estimator = LinearClassifier( feature_columns=[categorical_column_a, @@ -272,8 +283,9 @@ class LinearClassifier(estimator.Estimator): encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also there will be errors if vocabulary is not provided and labels are string. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to FTRL optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to FTRL optimizer. config: `RunConfig` object to configure the runtime settings. partitioner: Optional. Partitioner for input layer. warm_start_from: A string filepath to a checkpoint to warm-start from, or @@ -335,10 +347,31 @@ class LinearRegressor(estimator.Estimator): categorical_feature_a_x_categorical_feature_b = crossed_column(...) + # Estimator using the default optimizer. estimator = LinearRegressor( feature_columns=[categorical_column_a, categorical_feature_a_x_categorical_feature_b]) + # Or estimator using the FTRL optimizer with regularization. + estimator = LinearRegressor( + feature_columns=[categorical_column_a, + categorical_feature_a_x_categorical_feature_b], + optimizer=tf.train.FtrlOptimizer( + learning_rate=0.1, + l1_regularization_strength=0.001 + )) + + # Or estimator using an optimizer with a learning rate decay. + estimator = LinearRegressor( + feature_columns=[categorical_column_a, + categorical_feature_a_x_categorical_feature_b], + optimizer=lambda: tf.train.FtrlOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + # Or estimator with warm-starting from a previous checkpoint. estimator = LinearRegressor( feature_columns=[categorical_column_a, @@ -409,8 +442,9 @@ class LinearRegressor(estimator.Estimator): used as a key to fetch weight tensor from the `features`. If it is a `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then weight_column.normalizer_fn is applied on it to get weight tensor. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to FTRL optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to FTRL optimizer. config: `RunConfig` object to configure the runtime settings. partitioner: Optional. Partitioner for input layer. warm_start_from: A string filepath to a checkpoint to warm-start from, or diff --git a/tensorflow/python/estimator/canned/optimizers.py b/tensorflow/python/estimator/canned/optimizers.py index f72c5ca5cbb2721d967ad9ef9dfa896f7ccce240..8f51cc3a80dd9b91eb24a83577b7d0614615e008 100644 --- a/tensorflow/python/estimator/canned/optimizers.py +++ b/tensorflow/python/estimator/canned/optimizers.py @@ -72,6 +72,8 @@ def get_optimizer_instance(opt, learning_rate=None): raise ValueError( 'Unsupported optimizer name: {}. Supported names are: {}'.format( opt, tuple(sorted(six.iterkeys(_OPTIMIZER_CLS_NAMES))))) + if callable(opt): + opt = opt() if not isinstance(opt, optimizer_lib.Optimizer): raise ValueError( 'The given object is not an Optimizer instance. Given: {}'.format(opt)) diff --git a/tensorflow/python/estimator/canned/optimizers_test.py b/tensorflow/python/estimator/canned/optimizers_test.py index ee28756155afd5ae3421475c3d41542db9411345..eadabdbc496334270cd792f5b8d5ff39a446bcf7 100644 --- a/tensorflow/python/estimator/canned/optimizers_test.py +++ b/tensorflow/python/estimator/canned/optimizers_test.py @@ -28,6 +28,13 @@ from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import rmsprop +class _TestOptimizer(optimizer_lib.Optimizer): + + def __init__(self): + super(_TestOptimizer, self).__init__( + use_locking=False, name='TestOptimizer') + + class GetOptimizerInstance(test.TestCase): def test_unsupported_name(self): @@ -66,12 +73,6 @@ class GetOptimizerInstance(test.TestCase): self.assertAlmostEqual(0.1, opt._learning_rate) def test_object(self): - class _TestOptimizer(optimizer_lib.Optimizer): - - def __init__(self): - super(_TestOptimizer, self).__init__( - use_locking=False, name='TestOptimizer') - opt = optimizers.get_optimizer_instance(_TestOptimizer()) self.assertIsInstance(opt, _TestOptimizer) @@ -80,6 +81,23 @@ class GetOptimizerInstance(test.TestCase): ValueError, 'The given object is not an Optimizer instance'): optimizers.get_optimizer_instance((1, 2, 3)) + def test_callable(self): + def _optimizer_fn(): + return _TestOptimizer() + opt = optimizers.get_optimizer_instance(_optimizer_fn) + self.assertIsInstance(opt, _TestOptimizer) + + def test_lambda(self): + opt = optimizers.get_optimizer_instance(lambda: _TestOptimizer()) # pylint: disable=unnecessary-lambda + self.assertIsInstance(opt, _TestOptimizer) + + def test_callable_returns_invalid(self): + def _optimizer_fn(): + return (1, 2, 3) + with self.assertRaisesRegexp( + ValueError, 'The given object is not an Optimizer instance'): + optimizers.get_optimizer_instance(_optimizer_fn) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 2b87f7403fa91ddf9c3ccc8042ea8b2c1ea499cd..350a95eea1f1112ea270156855409d7a1b264bfb 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -38,6 +38,7 @@ from tensorflow.python.estimator import run_config from tensorflow.python.estimator import util as estimator_util from tensorflow.python.estimator.export import export as export_helpers from tensorflow.python.estimator.export import export_output +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -848,7 +849,8 @@ class Estimator(object): strip_default_attrs, save_variables=True, mode=model_fn_lib.ModeKeys.PREDICT, - export_tags=None): + export_tags=None, + check_variables=True): # pylint: disable=line-too-long """Loads variables and adds them along with a MetaGraphDef for saving. @@ -869,6 +871,10 @@ class Estimator(object): mode: tf.estimator.ModeKeys value indicating which mode will be exported. export_tags: The set of tags with which to save `MetaGraphDef`. If None, a default set will be selected to matched the passed mode. + check_variables: bool, whether to check the checkpoint has all variables. + + Raises: + ValueError: if `save_variables` is `True` and `check_variable` is `False`. """ # pylint: enable=line-too-long if export_tags is None: @@ -909,16 +915,20 @@ class Estimator(object): # SavedModel for restore later. graph_saver = estimator_spec.scaffold.saver or saver.Saver(sharded=True) - try: - graph_saver.restore(session, checkpoint_path) - except errors.NotFoundError as e: - msg = ('Could not load all requested variables from the checkpoint. ' - 'Please make sure your model_fn does not expect variables ' - 'that were not saved in the checkpoint.\n\n' - 'Encountered error with mode `{}` while restoring checkpoint ' - 'from: `{}`. Full Traceback:\n\n{}').format( - mode, checkpoint_path, e) - raise ValueError(msg) + if save_variables and not check_variables: + raise ValueError('If `save_variables` is `True, `check_variables`' + 'must not be `False`.') + if check_variables: + try: + graph_saver.restore(session, checkpoint_path) + except errors.NotFoundError as e: + msg = ('Could not load all requested variables from checkpoint. ' + 'Please make sure your model_fn does not expect variables ' + 'that were not saved in the checkpoint.\n\n' + 'Encountered error with mode `{}` while restoring ' + 'checkpoint from: `{}`. Full Traceback:\n\n{}').format( + mode, checkpoint_path, e) + raise ValueError(msg) # We add the train op explicitly for now, so that we don't have to # change the Builder public interface. Note that this is a no-op @@ -1133,6 +1143,18 @@ class Estimator(object): return self._train_model_default(input_fn, hooks, saving_listeners) def _train_model_default(self, input_fn, hooks, saving_listeners): + """Initiate training with input_fn, without DistributionStrategies. + + Args: + input_fn: A function that provides input data for training as minibatches. + hooks: List of `SessionRunHook` subclass instances. Used for callbacks + inside the training loop. + saving_listeners: list of `CheckpointSaverListener` objects. Used for + callbacks that run immediately before or after checkpoint savings. + + Returns: + Loss from training + """ worker_hooks = [] with ops.Graph().as_default() as g, g.device(self._device_fn): random_seed.set_random_seed(self._config.tf_random_seed) @@ -1149,26 +1171,86 @@ class Estimator(object): saving_listeners) def _train_model_distributed(self, input_fn, hooks, saving_listeners): + """Initiate training with input_fn, using DistributionStrategies. + + Args: + input_fn: A function that provides input data for training as minibatches. + hooks: List of `SessionRunHook` subclass instances. Used for callbacks + inside the training loop. + saving_listeners: list of `CheckpointSaverListener` objects. Used for + callbacks that run immediately before or after checkpoint savings. + + Returns: + Loss from training + """ self._distribution.configure(self._session_config) + + # TODO(sourabhbajaj): Remove this hack once we migrate the other strategies + # to use the new API + is_tpu_strategy = self._distribution.__class__.__name__ == 'TPUStrategy' + worker_hooks = [] with ops.Graph().as_default() as g: with self._distribution.scope(): random_seed.set_random_seed(self._config.tf_random_seed) - features, labels, input_hooks = ( - self._get_features_and_labels_from_input_fn( - input_fn, model_fn_lib.ModeKeys.TRAIN)) - worker_hooks.extend(input_hooks) - global_step_tensor = self._create_and_assert_global_step(g) - # we want to add to the global collection in the main thread not the - # tower threads. - ops.add_to_collection(training_util.GLOBAL_STEP_READ_KEY, - self._distribution.read_var(global_step_tensor)) - grouped_estimator_spec = self._distribution.call_for_each_tower( - self._call_model_fn, - features, - labels, # although this will be None it seems - model_fn_lib.ModeKeys.TRAIN, - self.config) + + if is_tpu_strategy: + # Create the iterator for run_on_dataset function + # TODO(sourabhbajaj): refactor this out to call a function on the + # strategy + dataset = self._distribution.distribute_dataset( + lambda: self._call_input_fn(input_fn, # pylint: disable=g-long-lambda + model_fn_lib.ModeKeys.TRAIN)) + iterator = dataset.make_initializable_iterator() + worker_hooks.append( + estimator_util._DatasetInitializerHook(iterator)) # pylint: disable=protected-access + + global_step_tensor = self._create_and_assert_global_step(g) + # we want to add to the global collection in the main thread not the + # tower threads. + ops.add_to_collection(training_util.GLOBAL_STEP_READ_KEY, + self._distribution.read_var(global_step_tensor)) + + # Create a step_fn from the train_op of grouped_estimator_spec + def step_fn(ctx, inputs): + """A single step that is passed to run_on_dataset.""" + features, labels = inputs + estimator_spec = self._distribution.call_for_each_tower( + self._call_model_fn, + features, + labels, + model_fn_lib.ModeKeys.TRAIN, + self.config) + ctx.last_step_outputs = estimator_spec.loss + ctx.non_tensor_outputs = {'estimator_spec': estimator_spec} + with ops.control_dependencies([estimator_spec.train_op]): + return array_ops.identity(estimator_spec.loss) + + # Create new train_op post graph rewrites + # TODO(sourabhbajaj): Make sure train_steps and tpu_iterations + # work correctly. Currently hardcoded at 2 + initial_training_loss = constant_op.constant(1e7) + distributed_train_op, tpu_result, ctx = \ + self._distribution._run_steps_on_dataset( # pylint: disable=protected-access + step_fn, iterator, iterations=2, + initial_loop_values=initial_training_loss) + grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec'] + else: + features, labels, input_hooks = ( + self._get_features_and_labels_from_input_fn( + input_fn, model_fn_lib.ModeKeys.TRAIN)) + worker_hooks.extend(input_hooks) + global_step_tensor = self._create_and_assert_global_step(g) + # we want to add to the global collection in the main thread not the + # tower threads. + ops.add_to_collection(training_util.GLOBAL_STEP_READ_KEY, + self._distribution.read_var(global_step_tensor)) + grouped_estimator_spec = self._distribution.call_for_each_tower( + self._call_model_fn, + features, + labels, # although this will be None it seems + model_fn_lib.ModeKeys.TRAIN, + self.config) # TODO(anjalisridhar): Figure out how to resolve the following scaffold # parameters: init_feed_dict, init_fn. @@ -1196,10 +1278,16 @@ class Estimator(object): else: init_op = None + def _unwrap_and_concat(value): + value = nest.flatten(self._distribution.unwrap(value)) + if len(value) != 1: + return array_ops.concat(value) + return value[0] + ready_op = self._distribution.call_for_each_tower( create_per_tower_ready_op, grouped_estimator_spec.scaffold) if ready_op is not None: - ready_op = self._distribution.group(ready_op) + ready_op = _unwrap_and_concat(ready_op) else: ready_op = None @@ -1207,8 +1295,7 @@ class Estimator(object): create_per_tower_ready_for_local_init_op, grouped_estimator_spec.scaffold) if ready_for_local_init_op is not None: - ready_for_local_init_op = self._distribution.group( - ready_for_local_init_op) + ready_for_local_init_op = _unwrap_and_concat(ready_for_local_init_op) else: ready_for_local_init_op = None @@ -1249,13 +1336,28 @@ class Estimator(object): training_chief_hooks = get_hooks_from_the_first_device( grouped_estimator_spec.training_chief_hooks) + # TODO(sourabhbajaj): Merge the two code paths once we can + # handle per device variables correctly in reduce and can output + # the loss scaler. + if is_tpu_strategy: + loss = self._distribution.unwrap( + self._distribution.reduce(distribute_lib.get_loss_reduction(), + tpu_result)[0])[0] + worker_hooks.append( + estimator_util.StrategyInitFinalizeHook( + self._distribution.get_initialization_ops, + self._distribution.get_finalize_ops)) + else: + loss = self._distribution.unwrap( + self._distribution.reduce(distribute_lib.get_loss_reduction(), + grouped_estimator_spec.loss, + destinations='/device:CPU:0'))[0] + distributed_train_op = grouped_estimator_spec.train_op + estimator_spec = model_fn_lib.EstimatorSpec( mode=grouped_estimator_spec.mode, - loss=self._distribution.unwrap( - self._distribution.reduce(distribute_lib.get_loss_reduction(), - grouped_estimator_spec.loss, - destinations='/device:CPU:0'))[0], - train_op=self._distribution.group(grouped_estimator_spec.train_op), + loss=loss, + train_op=self._distribution.group(distributed_train_op), training_hooks=training_hooks, training_chief_hooks=training_chief_hooks, scaffold=scaffold) diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index a43b820f322d70093a5015457fea294e436daeea..2a0e4e761755e272a316ce2d326b0c0a51ecbaba 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -38,6 +38,7 @@ from tensorflow.python.estimator.export import export_output from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util @@ -61,6 +62,7 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import loader_impl +from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary import summary from tensorflow.python.summary import summary_iterator @@ -1295,6 +1297,31 @@ class EstimatorEvaluateTest(test.TestCase): dummy_input_fn, steps=1, checkpoint_path=est1.latest_checkpoint()) self.assertEqual(5, scores['global_step']) + def test_wrong_shape_throws_reasonable_error(self): + """Make sure we are helpful when model_fns change. See b/110263146.""" + def _get_model_fn(val=1): + def _model_fn(features, labels, mode): + del features, labels # unused + variables.Variable(val, name='weight') + return model_fn_lib.EstimatorSpec( + mode=mode, + predictions=constant_op.constant([[1.]]), + loss=constant_op.constant(0.), + train_op=state_ops.assign_add(training.get_global_step(), 1)) + return _model_fn + + model_fn_1 = _get_model_fn() + model_fn_2 = _get_model_fn(val=[1]) + + est1 = estimator.Estimator(model_fn=model_fn_1) + est1.train(dummy_input_fn, steps=5) + est2 = estimator.Estimator( + model_fn=model_fn_2, model_dir=est1.model_dir) + + expected_msg = 'Restoring from checkpoint failed.*a mismatch between' + with self.assertRaisesRegexp(errors.InvalidArgumentError, expected_msg): + est2.train(dummy_input_fn, steps=1,) + def test_scaffold_is_used(self): def _model_fn_scaffold(features, labels, mode): @@ -2829,6 +2856,45 @@ class EstimatorExportTest(test.TestCase): # Clean up. gfile.DeleteRecursively(tmpdir) + def test_export_savedmodel_no_export_outputs(self): + """Ensure that an EstimatorSpec without outputs defined can be exported.""" + + def _model_fn(features, labels, mode): + _, _ = features, labels + variables.Variable(1., name='weight') + return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(training.get_global_step(), 1)) + + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn) + est.train(input_fn=dummy_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('no_export_outputs')) + export_dir = est.export_savedmodel( + export_dir_base, _get_serving_input_receiver_fn()) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + self._validate_exported_files(export_dir) + + # Restore, to validate that the export was well-formed. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + meta_graph = loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('weight' in graph_ops) + + sig_def = meta_graph.signature_def + self.assertEqual(len(sig_def), 1) + sig_outputs = sig_def[ + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs + self.assertEqual(sig_outputs['output'].name, 'Const:0') + class EstimatorHookOrderingTest(test.TestCase): @@ -2873,7 +2939,7 @@ class EstimatorHookOrderingTest(test.TestCase): class EstimatorIntegrationTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_complete_flow_with_a_simple_linear_model(self): def _model_fn(features, labels, mode): diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py index 2f439f765e6811335667b62437f7aafc934904dc..5769f5739c5877ceeb8bc7234896e96672a3127d 100644 --- a/tensorflow/python/estimator/keras.py +++ b/tensorflow/python/estimator/keras.py @@ -45,7 +45,6 @@ from tensorflow.python.saved_model import signature_constants from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util -from tensorflow.python.util.tf_export import tf_export _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -123,8 +122,8 @@ def _create_ordered_io(keras_model, estimator_io, is_input=True): 'It needs to match one ' 'of the following: %s' % ('input' if is_input else 'output', key, ', '.join(keras_io_names))) - tensors = [_convert_tensor(estimator_io[io_name]) - for io_name in keras_io_names] + tensors = [_convert_tensor(estimator_io[io_name]) + for io_name in keras_io_names] return tensors else: # Plain array. @@ -446,7 +445,6 @@ def _save_first_checkpoint(keras_model, estimator, custom_objects, saver.save(sess, os.path.join(estimator.model_dir, 'keras_model.ckpt')) -@tf_export('keras.estimator.model_to_estimator') def model_to_estimator(keras_model=None, keras_model_path=None, custom_objects=None, @@ -455,7 +453,7 @@ def model_to_estimator(keras_model=None, """Constructs an `Estimator` instance from given keras model. For usage example, please see - @{$programmers_guide/estimators$creating_estimators_from_keras_models}. + @{$guide/estimators$creating_estimators_from_keras_models}. Args: keras_model: A compiled Keras model object. This argument is mutually diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py index c60c7f63bacc810e447bcafe954c55cb49ede7e0..a9fd8f8e1a4259fece1a5996343970900c853ce0 100644 --- a/tensorflow/python/estimator/model_fn.py +++ b/tensorflow/python/estimator/model_fn.py @@ -23,7 +23,7 @@ import collections import six -from tensorflow.python.estimator.export.export_output import ExportOutput +from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -99,7 +99,7 @@ class EstimatorSpec( ignored in eval and infer modes. Example: ```python - def my_model_fn(mode, features, labels): + def my_model_fn(features, labels, mode): predictions = ... loss = ... train_op = ... @@ -114,7 +114,7 @@ class EstimatorSpec( given mode. Example: ```python - def my_model_fn(mode, features, labels): + def my_model_fn(features, labels, mode): if (mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL): loss = ... @@ -158,6 +158,8 @@ class EstimatorSpec( Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY. + If no entry is provided, a default `PredictOutput` mapping to + `predictions` will be created. training_chief_hooks: Iterable of `tf.train.SessionRunHook` objects to run on the chief worker during training. training_hooks: Iterable of `tf.train.SessionRunHook` objects to run @@ -232,29 +234,9 @@ class EstimatorSpec( _check_is_tensor_or_operation(metric_update, 'eval_metric_ops[{}]'.format(key)) - # Validate export_outputs. - if export_outputs is not None: - if not isinstance(export_outputs, dict): - raise TypeError('export_outputs must be dict, given: {}'.format( - export_outputs)) - for v in six.itervalues(export_outputs): - if not isinstance(v, ExportOutput): - raise TypeError( - 'Values in export_outputs must be ExportOutput objects. ' - 'Given: {}'.format(export_outputs)) - # Note export_outputs is allowed to be empty. - if len(export_outputs) == 1: - (key, value), = export_outputs.items() - if key != signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: - export_outputs[ - signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = value - if len(export_outputs) > 1: - if (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - not in export_outputs): - raise ValueError( - 'Multiple export_outputs were provided, but none of them is ' - 'specified as the default. Do this by naming one of them with ' - 'signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.') + # Validate the passed export outputs, or generate defaults. + if mode == ModeKeys.PREDICT: + export_outputs = _get_export_outputs(export_outputs, predictions) # Validate that all tensors and ops are from the default graph. default_graph = ops.get_default_graph() @@ -286,11 +268,11 @@ class EstimatorSpec( raise ValueError(error_message_template.format('train_op', train_op.name)) for key, value in list(six.iteritems(eval_metric_ops)): values = nest.flatten(value) - for value in values: - if value.graph is not default_graph: + for val in values: + if val.graph is not default_graph: raise ValueError(error_message_template.format( 'eval_metric_ops', - '{0}: {1}'.format(key, value.name))) + '{0}: {1}'.format(key, val.name))) # Validate hooks. training_chief_hooks = tuple(training_chief_hooks or []) @@ -334,6 +316,70 @@ class EstimatorSpec( return EstimatorSpec(*new_fields) +def _get_export_outputs(export_outputs, predictions): + """Validate export_outputs or create default export_outputs. + + Args: + export_outputs: Describes the output signatures to be exported to + `SavedModel` and used during serving. Should be a dict or None. + predictions: Predictions `Tensor` or dict of `Tensor`. + + Returns: + Valid export_outputs dict + + Raises: + TypeError: if export_outputs is not a dict or its values are not + ExportOutput instances. + """ + if export_outputs is None: + default_output = export_output_lib.PredictOutput(predictions) + export_outputs = { + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: default_output} + + if not isinstance(export_outputs, dict): + raise TypeError('export_outputs must be dict, given: {}'.format( + export_outputs)) + for v in six.itervalues(export_outputs): + if not isinstance(v, export_output_lib.ExportOutput): + raise TypeError( + 'Values in export_outputs must be ExportOutput objects. ' + 'Given: {}'.format(export_outputs)) + + _maybe_add_default_serving_output(export_outputs) + + return export_outputs + + +def _maybe_add_default_serving_output(export_outputs): + """Add a default serving output to the export_outputs if not present. + + Args: + export_outputs: Describes the output signatures to be exported to + `SavedModel` and used during serving. Should be a dict. + + Returns: + export_outputs dict with default serving signature added if necessary + + Raises: + ValueError: if multiple export_outputs were provided without a default + serving key. + """ + if len(export_outputs) == 1: + (key, value), = export_outputs.items() + if key != signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + export_outputs[ + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = value + if len(export_outputs) > 1: + if (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + not in export_outputs): + raise ValueError( + 'Multiple export_outputs were provided, but none of them is ' + 'specified as the default. Do this by naming one of them with ' + 'signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.') + + return export_outputs + + class _TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [ 'mode', 'predictions', diff --git a/tensorflow/python/estimator/model_fn_test.py b/tensorflow/python/estimator/model_fn_test.py index b7eeeb437cb4a624cdee552be3032364b18a8290..08e41fd4146e9254fc8cc7da6bc809e80d053a5b 100644 --- a/tensorflow/python/estimator/model_fn_test.py +++ b/tensorflow/python/estimator/model_fn_test.py @@ -592,6 +592,27 @@ class EstimatorSpecInferTest(test.TestCase): predictions=predictions, export_outputs=export_outputs) + def testDefaultExportOutputCreated(self): + """Ensure that a default PredictOutput is created for export.""" + with ops.Graph().as_default(), self.test_session(): + predictions = constant_op.constant(1.) + self._assertDefaultExportOutputForPredictions(predictions) + + def testDefaultExportOutputCreatedDict(self): + """Ensure that a default PredictOutput is created for export for dicts.""" + with ops.Graph().as_default(), self.test_session(): + predictions = {'loss': constant_op.constant(1.), + 'score': constant_op.constant(10.)} + self._assertDefaultExportOutputForPredictions(predictions) + + def _assertDefaultExportOutputForPredictions(self, predictions): + spec = model_fn.EstimatorSpec( + mode=model_fn.ModeKeys.PREDICT, predictions=predictions) + + expected = export_output.PredictOutput(predictions).outputs + serving_output = spec.export_outputs[ + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] + self.assertEqual(serving_output.outputs, expected) if __name__ == '__main__': test.main() diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py index b948ce96e0c09c0537619366403658408cf17895..3d60c63b68968c98a00364948bd3de0581daadd4 100644 --- a/tensorflow/python/estimator/run_config.py +++ b/tensorflow/python/estimator/run_config.py @@ -25,6 +25,7 @@ import os import six from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib from tensorflow.python.util import compat_internal @@ -484,6 +485,43 @@ class RunConfig(object): self._init_distributed_setting_from_environment_var(tf_config) + # Get session_config only for distributed mode (cluster_spec is present). + if not self._session_config and self._cluster_spec: + RunConfig._replace( + self, + allowed_properties_list=_DEFAULT_REPLACEABLE_LIST, + session_config=self._get_default_session_config()) + + def _get_default_session_config(self): + """Returns None or tf.ConfigProto instance with default device_filters set. + + Device filters are set such that chief/master and worker communicates with + only ps. session_config=None for evaluators or any other TaskType. + """ + + rewrite_opts = rewriter_config_pb2.RewriterConfig( + meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE) + graph_opts = config_pb2.GraphOptions(rewrite_options=rewrite_opts) + + device_filters = None + if self._task_type == TaskType.MASTER: + device_filters = ['/job:ps', '/job:master'] + elif self._task_type == TaskType.CHIEF: + device_filters = ['/job:ps', '/job:chief'] + elif self._task_type == TaskType.WORKER: + device_filters = ['/job:ps', '/job:worker/task:%d' % self._task_id] + elif self._task_type == TaskType.PS: + device_filters = ['/job:ps', '/job:worker', '/job:master'] + else: + # If the task_type is `EVALUATOR` or something other than the ones in + # TaskType then don't set any device filters. + return None + + return config_pb2.ConfigProto( + allow_soft_placement=True, + graph_options=graph_opts, + device_filters=device_filters) + def _init_distributed_setting_from_environment_var(self, tf_config): """Initialize distributed properties based on `tf_config`.""" diff --git a/tensorflow/python/estimator/run_config_test.py b/tensorflow/python/estimator/run_config_test.py index c8b12605e1aaad11e114e4ace63697b93f3b2b92..06df7cb9dd4ae3d167d622601e551079b64e80a2 100644 --- a/tensorflow/python/estimator/run_config_test.py +++ b/tensorflow/python/estimator/run_config_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import json from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.platform import test @@ -290,6 +291,7 @@ class RunConfigDistributedSettingTest(test.TestCase): expected_num_worker_replicas=1, expected_num_ps_replicas=0) self.assertEqual(0, run_config.global_id_in_cluster) + self.assertIsNone(run_config.session_config, None) def test_session_master_for_local(self): tf_config = {'session_master': '_my_master'} @@ -1119,5 +1121,115 @@ class RunConfigModelDirTest(test.TestCase): _create_run_config_with_cluster_spec(tf_config) +class RunConfigSessionConfigTest(test.TestCase): + + def _assert_equal_session_config(self, session_config, + expected_device_filters): + + rewrite_opts = rewriter_config_pb2.RewriterConfig( + meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE) + graph_opts = config_pb2.GraphOptions(rewrite_options=rewrite_opts) + expected_session_config = config_pb2.ConfigProto( + allow_soft_placement=True, + graph_options=graph_opts, + device_filters=expected_device_filters) + self.assertEqual(session_config, expected_session_config) + + def test_master_session_config(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.MASTER: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': run_config_lib.TaskType.MASTER, + 'index': 0 + } + } + run_config = _create_run_config_with_cluster_spec(tf_config) + self._assert_equal_session_config(run_config.session_config, + ['/job:ps', '/job:master']) + + def test_chief_session_config(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': run_config_lib.TaskType.CHIEF, + 'index': 0 + } + } + run_config = _create_run_config_with_cluster_spec(tf_config) + self._assert_equal_session_config(run_config.session_config, + ['/job:ps', '/job:chief']) + + def test_worker_session_config(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.MASTER: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': run_config_lib.TaskType.WORKER, + 'index': 1 + } + } + run_config = _create_run_config_with_cluster_spec(tf_config) + self._assert_equal_session_config(run_config.session_config, + ['/job:ps', '/job:worker/task:1']) + + def test_ps_session_config(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.MASTER: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': run_config_lib.TaskType.PS, + 'index': 1 + } + } + run_config = _create_run_config_with_cluster_spec(tf_config) + self._assert_equal_session_config(run_config.session_config, + ['/job:ps', '/job:worker', '/job:master']) + + def test_evaluator_session_config(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': run_config_lib.TaskType.EVALUATOR, + 'index': 0 + } + } + run_config = _create_run_config_with_cluster_spec(tf_config) + self.assertIsNone(run_config.session_config) + + def test_other_type_session_config(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.MASTER: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], + 'other_type': ['host3:1', 'host4:2'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': 'other_type', + 'index': 0 + } + } + run_config = _create_run_config_with_cluster_spec(tf_config) + self.assertIsNone(run_config.session_config) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index 1572af579b964e8cf5cdb3d5d11a56d80b965b5c..57301010920be90c63e00594d686df3a09466c91 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -278,10 +278,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec): supported distributed training configuration is between-graph replication. Overfitting: In order to avoid overfitting, it is recommended to set up the - training `input_fn` to shuffle the training data properly. It is also - recommended to train the model a little longer, say multiple epochs, before - performing evaluation, as the input pipeline starts from scratch for each - training. It is particularly important for local training and evaluation. + training `input_fn` to shuffle the training data properly. Stop condition: In order to support both distributed and non-distributed configuration reliably, the only supported stop condition for model @@ -470,6 +467,61 @@ class _StopAtSecsHook(session_run_hook.SessionRunHook): run_context.request_stop() +class _NewCheckpointListenerForEvaluate( + basic_session_run_hooks.CheckpointSaverListener): + """A saver listener to run evaluate with every checkpoint.""" + + def __init__(self, evaluator, eval_throttle_secs, continuous_eval_listener): + self._evaluator = evaluator + self._eval_throttle_secs = eval_throttle_secs + self._continuous_eval_listener = continuous_eval_listener + self.eval_result, self.export_results = None, None + + def begin(self): + self._timer = basic_session_run_hooks.SecondOrStepTimer( + every_secs=self._eval_throttle_secs) + self._is_first_run = True + + def after_save(self, session, global_step_value): + del session # unused; required by signature. + # skip first run model is not trained yet. + if self._is_first_run: + self._is_first_run = False + return + + if not self._continuous_eval_listener.before_eval(): + logging.info('Exiting training and evaluation loop, as requested by ' + '_ContinuousEvalListener.before_eval.') + return True + if self._timer.should_trigger_for_step(global_step_value): + self._evaluate(global_step_value) # updates self.eval_result + if not self._continuous_eval_listener.after_eval(self.eval_result): + logging.info('Exiting evaluation, as requested by ' + '_ContinuousEvalListener.after_eval.') + return True + else: + # TODO(ispir): add remaining time in the log. + logging.info('Skip the current checkpoint eval due to throttle secs ' + '({} secs).'.format(self._eval_throttle_secs)) + + def end(self, session, global_step_value): + # Evaluate if the last step has not been evaluated, yet. + if global_step_value != self._timer.last_triggered_step(): + if self._continuous_eval_listener.before_eval(): + self._evaluate(global_step_value) + self._continuous_eval_listener.after_eval(self.eval_result) + + def _evaluate(self, global_step_value): + self._timer.update_last_triggered_step(global_step_value) + self.eval_result, self.export_results = ( + self._evaluator.evaluate_and_export()) + if self.eval_result.status != _EvalStatus.EVALUATED: + # This is unexpected; should never happen. + # Training should always end with a new checkpoint. + raise RuntimeError('There was no new checkpoint after the training. ' + 'Eval status: {}'.format(self.eval_result.status)) + + class _TrainingExecutor(object): """The executor to run `Estimator` training and evaluation. @@ -576,28 +628,6 @@ class _TrainingExecutor(object): def run_master(self): """Runs task master.""" - - class NewCheckpointListener( - basic_session_run_hooks.CheckpointSaverListener): - - def __init__(self, evaluator, eval_throttle_secs): - self._evaluator = evaluator - self._eval_throttle_secs = eval_throttle_secs - - def begin(self): - self._timer = basic_session_run_hooks.SecondOrStepTimer( - every_secs=self._eval_throttle_secs) - - def after_save(self, session, global_step_value): - del session # unused; required by signature. - - if self._timer.should_trigger_for_step(global_step_value): - self._timer.update_last_triggered_step(global_step_value) - self._evaluator.evaluate_and_export() - else: - logging.info('Skip the current checkpoint eval due to throttle secs ' - '({} secs).'.format(self._eval_throttle_secs)) - _assert_eval_spec(self._eval_spec) # Final export signal: For any eval result with global_step >= train @@ -617,16 +647,12 @@ class _TrainingExecutor(object): # When the underlying `Estimator` object saves a new checkpoint, we would # like this callback to be called so that evaluation and export can trigger. saving_listeners = [ - NewCheckpointListener(evaluator, self._eval_spec.throttle_secs) + _NewCheckpointListenerForEvaluate(evaluator, + self._eval_spec.throttle_secs, + _ContinuousEvalListener()) ] self._start_distributed_training(saving_listeners=saving_listeners) - if not evaluator.is_final_export_triggered: - logging.info('Training has already ended. But the last eval is skipped ' - 'due to eval throttle_secs. Now evaluating the final ' - 'checkpoint.') - evaluator.evaluate_and_export() - def run_evaluator(self): """Runs task evaluator.""" # TODO(xiejw): To allow execution framework to add continuous eval listener. @@ -640,68 +666,33 @@ class _TrainingExecutor(object): def run_local(self): """Runs training and evaluation locally (non-distributed).""" - - def _should_stop_local_train(global_step): - if self._train_spec.max_steps is None: - return False - if global_step >= self._train_spec.max_steps: - return True - return False - _assert_eval_spec(self._eval_spec) - if self._eval_spec.throttle_secs <= 0: - raise ValueError('eval_spec.throttle_secs should be positive, given: {}.' - 'It is used do determine how long each training ' - 'iteration should go when train and evaluate ' - 'locally.'.format(self._eval_spec.throttle_secs)) - - stop_hook = _StopAtSecsHook(self._eval_spec.throttle_secs) - train_hooks = ( - list(self._train_spec.hooks) + [stop_hook] + list(self._train_hooks)) + train_hooks = list(self._train_spec.hooks) + list(self._train_hooks) logging.info('Start train and evaluate loop. The evaluate will happen ' - 'after {} secs (eval_spec.throttle_secs) or training is ' - 'finished.'.format(self._eval_spec.throttle_secs)) + 'after every checkpoint. Checkpoint frequency is determined ' + 'based on RunConfig arguments: save_checkpoints_steps {} or ' + 'save_checkpoints_secs {}.'.format( + self._estimator.config.save_checkpoints_steps, + self._estimator.config.save_checkpoints_secs)) evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec, self._train_spec.max_steps) - eval_result = _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT) - export_results = [] - - while True: - self._estimator.train( - input_fn=self._train_spec.input_fn, - max_steps=self._train_spec.max_steps, - hooks=train_hooks) - - if not self._continuous_eval_listener.before_eval(): - logging.info('Exiting training and evaluation loop, as requested by ' - '_ContinuousEvalListener.before_eval.') - break - - # Final export signal: For any eval result with global_step >= train - # max_steps, the evaluator will send the final export signal. The - # _should_stop_local_train will then end the while True as the stopping - # condition is satisfied (both checks use the same global_step value, - # i.e., no race condition) - eval_result, export_results = evaluator.evaluate_and_export() - - if eval_result.status != _EvalStatus.EVALUATED: - # This is unexpected; should never happen. - # Training should always end with a new checkpoint. - raise RuntimeError('There was no new checkpoint after the training. ' - 'Eval status: {}'.format(eval_result.status)) - - if not self._continuous_eval_listener.after_eval(eval_result): - logging.info('Exiting evaluation, as requested by ' - '_ContinuousEvalListener.after_eval.') - break + listener_for_eval = _NewCheckpointListenerForEvaluate( + evaluator, self._eval_spec.throttle_secs, + self._continuous_eval_listener) + saving_listeners = [listener_for_eval] + + self._estimator.train( + input_fn=self._train_spec.input_fn, + max_steps=self._train_spec.max_steps, + hooks=train_hooks, + saving_listeners=saving_listeners) - if _should_stop_local_train( - eval_result.metrics[ops.GraphKeys.GLOBAL_STEP]): - break - return eval_result.metrics, export_results + eval_result = listener_for_eval.eval_result or _EvalResult( + status=_EvalStatus.MISSING_CHECKPOINT) + return eval_result.metrics, listener_for_eval.export_results def _start_std_server(self, config): """Creates, starts, and returns a server_lib.Server.""" diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py index 2c838db7a4de98d941752ce9d5ddf8f2b47a46f1..6bee7cbe83a5e9b623ea16ebe48cce93e27534e2 100644 --- a/tensorflow/python/estimator/training_test.py +++ b/tensorflow/python/estimator/training_test.py @@ -29,17 +29,21 @@ import time import numpy as np +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import exporter as exporter_lib +from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.estimator import training from tensorflow.python.estimator.canned import dnn from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.export import export as export_lib -from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column import feature_column +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import state_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging @@ -49,6 +53,7 @@ from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import monitored_session from tensorflow.python.training import server_lib from tensorflow.python.training import session_run_hook +from tensorflow.python.training import training_util from tensorflow.python.util import compat _DEFAULT_EVAL_STEPS = 100 @@ -885,7 +890,8 @@ class TrainingExecutorRunMasterTest(test.TestCase): # `after_save`. del args, kwargs saving_listeners[0].begin() - saving_listeners[0].after_save(session=None, global_step_value=None) + saving_listeners[0].after_save(session=None, global_step_value=0) + saving_listeners[0].after_save(session=None, global_step_value=10) mock_est = test.mock.Mock( spec=estimator_lib.Estimator, model_dir='path/', train=estimator_train) @@ -930,7 +936,10 @@ class TrainingExecutorRunMasterTest(test.TestCase): del args, kwargs saving_listeners[0].begin() - # Call three times. + # Call four times. + mock_timer.should_trigger_for_step.return_value = True + saving_listeners[0].after_save(session=None, global_step_value=None) + mock_timer.should_trigger_for_step.return_value = True saving_listeners[0].after_save(session=None, global_step_value=None) @@ -979,14 +988,19 @@ class TrainingExecutorRunMasterTest(test.TestCase): del args, kwargs saving_listeners[0].begin() - # Call two times. + # Call tree times (one for first saving). mock_timer.should_trigger_for_step.return_value = True - saving_listeners[0].after_save(session=None, global_step_value=None) + saving_listeners[0].after_save(session=None, global_step_value=0) + + mock_timer.should_trigger_for_step.return_value = True + saving_listeners[0].after_save(session=None, global_step_value=125) - # The final ckpt is skipped by the timer. It will be picked up the final - # export check in the code. mock_timer.should_trigger_for_step.return_value = False - saving_listeners[0].after_save(session=None, global_step_value=None) + saving_listeners[0].after_save(session=None, global_step_value=250) + + # At the end evaluate should be called even if throttle secs prevents it. + mock_timer.should_trigger_for_step.return_value = False + saving_listeners[0].end(session=None, global_step_value=300) mock_est.train = estimator_train mock_est.latest_checkpoint.side_effect = ['ckpt1', 'ckpt2'] @@ -1566,28 +1580,31 @@ class StopAtSecsHookTest(test.TestCase): class TrainingExecutorRunLocalTest(test.TestCase): """Tests run_local of _TrainingExecutor.""" + def _model_fn(self, features, labels, mode): + del labels + with ops.control_dependencies([features]): + train_op = state_ops.assign_add(training_util.get_global_step(), 1) + return model_fn_lib.EstimatorSpec( + mode, + loss=constant_op.constant(0.), + train_op=train_op, + predictions=constant_op.constant([[10.]]), + eval_metric_ops={'mean_of_features': metrics_lib.mean(features)}) + + def _input_fn(self, repeat=True): + ds = dataset_ops.Dataset.from_tensors([1]) + if repeat: + return ds.repeat() + return ds + def unique_checkpoint_every_time_fn(self): return 'checkpoint_path_%s/' % random.random() - def test_send_stop_at_secs_to_train(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') - mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn - train_spec = training.TrainSpec( - input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()]) - eval_spec = training.EvalSpec( - input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100) - mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps} - - executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) - executor.run_local() - - stop_hook = mock_est.train.call_args[1]['hooks'][-1] - self.assertIsInstance(stop_hook, training._StopAtSecsHook) - self.assertEqual(eval_spec.throttle_secs, stop_hook._stop_after_secs) - - def test_runs_in_a_loop_until_max_steps(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') - mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn + def test_runs_evaluate_with_every_new_checkpoint(self): + est = estimator_lib.Estimator( + model_fn=self._model_fn, + config=run_config_lib.RunConfig(save_checkpoints_steps=10)) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) mock_est.times_export_was_called = 0 mock_est.times_final_export_was_called = 0 @@ -1604,42 +1621,30 @@ class TrainingExecutorRunLocalTest(test.TestCase): exporter.name = 'see_how_many_times_export_is_called' exporter.export = export - train_spec = training.TrainSpec( - input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) + train_spec = training.TrainSpec(input_fn=self._input_fn, max_steps=22) eval_spec = training.EvalSpec( - input_fn=lambda: 1, - hooks=[_FakeHook()], - throttle_secs=100, + input_fn=lambda: self._input_fn(repeat=False), + throttle_secs=0, exporters=exporter) - # should be called 3 times. - mock_est.evaluate.side_effect = [{ - _GLOBAL_STEP_KEY: train_spec.max_steps - 100 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - 50 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - }] executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) executor.run_local() - self.assertEqual(3, mock_est.train.call_count) + self.assertEqual(1, mock_est.train.call_count) self.assertEqual(3, mock_est.evaluate.call_count) self.assertEqual(3, mock_est.times_export_was_called) self.assertEqual(1, mock_est.times_final_export_was_called) def test_runs_with_eval_listener_before_eval(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') + est = estimator_lib.Estimator( + model_fn=self._model_fn, + config=run_config_lib.RunConfig(save_checkpoints_steps=10)) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn - train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300) - eval_spec = training.EvalSpec(input_fn=lambda: 1, throttle_secs=100) - # should be called 2 times without the evallistener - mock_est.evaluate.side_effect = [{ - _GLOBAL_STEP_KEY: train_spec.max_steps - 50 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - }] + train_spec = training.TrainSpec(input_fn=self._input_fn, max_steps=12) + eval_spec = training.EvalSpec(input_fn=lambda: self._input_fn(repeat=False)) + mock_est.evaluate.side_effect = [{_GLOBAL_STEP_KEY: train_spec.max_steps}] class _Listener(training._ContinuousEvalListener): @@ -1658,67 +1663,61 @@ class TrainingExecutorRunLocalTest(test.TestCase): self.assertEqual(1, mock_est.train.call_count) self.assertEqual(0, mock_est.evaluate.call_count) - self.assertEqual(1, listener.call_count) def test_runs_with_eval_listener_after_eval(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') - mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn + est = estimator_lib.Estimator( + model_fn=self._model_fn, + config=run_config_lib.RunConfig(save_checkpoints_steps=10)) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) - train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300) - eval_spec = training.EvalSpec(input_fn=lambda: 1, throttle_secs=100) - # should be called 2 times without the evallistener - mock_est.evaluate.side_effect = [{ - _GLOBAL_STEP_KEY: train_spec.max_steps - 50 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - }] + train_spec = training.TrainSpec(input_fn=self._input_fn, max_steps=3000) + eval_spec = training.EvalSpec( + input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0) class _Listener(training._ContinuousEvalListener): - def __init__(self, test_case): + def __init__(self): self.call_count = 0 - self._test_case = test_case def after_eval(self, eval_result): self.call_count += 1 - self._test_case.assertEqual( - train_spec.max_steps - 50, eval_result.metrics[_GLOBAL_STEP_KEY]) return False # Will stop the run_local after first eval. - listener = _Listener(test_case=self) + listener = _Listener() executor = training._TrainingExecutor( mock_est, train_spec, eval_spec, continuous_eval_listener=listener) - executor.run_local() + metrics, _ = executor.run_local() # pylint: disable=assignment-from-no-return self.assertEqual(1, mock_est.train.call_count) self.assertEqual(1, mock_est.evaluate.call_count) self.assertEqual(1, listener.call_count) + # Should be less than max_steps since listener did early stopping. + self.assertLess(metrics[_GLOBAL_STEP_KEY], train_spec.max_steps) def test_handles_no_new_checkpoint_found(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') - mock_est.latest_checkpoint.return_value = ( - 'no_new_checkpoints_after_the_first_train_step') + est = estimator_lib.Estimator( + model_fn=self._model_fn, + # disable saving checkpoint + config=run_config_lib.RunConfig( + save_checkpoints_steps=None, save_checkpoints_secs=None)) train_spec = training.TrainSpec( - input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) + input_fn=self._input_fn, max_steps=300, hooks=[_FakeHook()]) eval_spec = training.EvalSpec( - input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100) - # It was going to be called 3 times. - mock_est.evaluate.side_effect = [{ - _GLOBAL_STEP_KEY: train_spec.max_steps - 100 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - 50 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - }] + input_fn=lambda: self._input_fn(repeat=False), + hooks=[_FakeHook()], + throttle_secs=100) - executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) - with self.assertRaisesRegexp(RuntimeError, _STALE_CHECKPOINT_MSG): + executor = training._TrainingExecutor(est, train_spec, eval_spec) + with self.assertRaisesRegexp(ValueError, + 'There should be a CheckpointSaverHook'): executor.run_local() def test_final_export_is_true_in_the_end(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') - mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn + est = estimator_lib.Estimator( + model_fn=self._model_fn, + config=run_config_lib.RunConfig(save_checkpoints_steps=10)) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) mock_est.times_export_fn_was_called = 0 mock_est.times_the_final_export_was_true = 0 @@ -1734,37 +1733,29 @@ class TrainingExecutorRunLocalTest(test.TestCase): exporter.export = export train_spec = training.TrainSpec( - input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) + input_fn=self._input_fn, max_steps=12, hooks=[_FakeHook()]) eval_spec = training.EvalSpec( - input_fn=lambda: 1, - hooks=[_FakeHook()], - throttle_secs=100, + input_fn=lambda: self._input_fn(repeat=False), + throttle_secs=0, exporters=exporter) - # should be called 3 times. - mock_est.evaluate.side_effect = [{ - _GLOBAL_STEP_KEY: train_spec.max_steps - 100 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - 50 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - }] - executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) executor.run_local() - self.assertEqual(3, mock_est.train.call_count) - self.assertEqual(3, mock_est.evaluate.call_count) - self.assertEqual(3, mock_est.times_export_fn_was_called) + self.assertEqual(1, mock_est.train.call_count) + self.assertEqual(2, mock_est.evaluate.call_count) + self.assertEqual(2, mock_est.times_export_fn_was_called) self.assertEqual(1, mock_est.times_the_final_export_was_true) def test_train_and_evaluate_args(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') - mock_est.latest_checkpoint.return_value = 'checkpoint_path/' + est = estimator_lib.Estimator(model_fn=self._model_fn) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) train_spec = training.TrainSpec( - input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) + input_fn=self._input_fn, max_steps=300, hooks=[_FakeHook()]) eval_spec = training.EvalSpec( - input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='local_eval') - mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps} + input_fn=lambda: self._input_fn(repeat=False), + steps=2, + hooks=[_FakeHook()], + name='local_eval') executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) executor.run_local() @@ -1773,11 +1764,11 @@ class TrainingExecutorRunLocalTest(test.TestCase): name=eval_spec.name, input_fn=eval_spec.input_fn, steps=eval_spec.steps, - checkpoint_path='checkpoint_path/', + checkpoint_path=est.latest_checkpoint(), hooks=eval_spec.hooks) train_args = mock_est.train.call_args[1] - self.assertEqual(list(train_spec.hooks), list(train_args['hooks'][:-1])) + self.assertEqual(list(train_spec.hooks), list(train_args['hooks'])) self.assertEqual(train_spec.input_fn, train_args['input_fn']) self.assertEqual(train_spec.max_steps, train_args['max_steps']) @@ -1812,25 +1803,11 @@ class TrainingExecutorRunLocalTest(test.TestCase): if not isinstance(h, training._StopAtSecsHook) ]) - def test_errors_out_if_throttle_secs_is_zero(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator) - train_spec = training.TrainSpec(input_fn=lambda: 1) - eval_spec = training.EvalSpec(input_fn=lambda: 1, throttle_secs=0) - - executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) - with self.assertRaisesRegexp(ValueError, 'throttle_secs'): - executor.run_local() - def test_that_export_is_called_with_run_local(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator) - mock_train_spec = test.mock.Mock(spec=training.TrainSpec) - mock_train_spec.max_steps = 200 - mock_est.evaluate.return_value = { - _GLOBAL_STEP_KEY: mock_train_spec.max_steps - } - # _validate_hooks would have made sure that train_spec.hooks is [], when - # None were passed. - mock_train_spec.hooks = [] + est = estimator_lib.Estimator(model_fn=self._model_fn) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) + train_spec = training.TrainSpec(input_fn=self._input_fn, max_steps=12) + mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps} def export(estimator, *args, **kwargs): del args, kwargs @@ -1842,13 +1819,13 @@ class TrainingExecutorRunLocalTest(test.TestCase): exporter.export = export eval_spec = training.EvalSpec( - input_fn=lambda: 1, + input_fn=lambda: self._input_fn(repeat=False), steps=2, start_delay_secs=0, throttle_secs=213, exporters=exporter) - executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec) + executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) # pylint: disable=assignment-from-no-return _, export_results = executor.run_local() # pylint: enable=assignment-from-no-return @@ -1857,9 +1834,13 @@ class TrainingExecutorRunLocalTest(test.TestCase): self.assertEqual(export_results, ['path_to_export']) def test_errors_out_if_evaluate_returns_empty_dict(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator) - train_spec = training.TrainSpec(input_fn=lambda: 1) - eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123) + est = estimator_lib.Estimator( + model_fn=self._model_fn, + config=run_config_lib.RunConfig(save_checkpoints_steps=2)) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) + train_spec = training.TrainSpec(input_fn=self._input_fn) + eval_spec = training.EvalSpec( + input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0) mock_est.evaluate.return_value = {} executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) @@ -1867,18 +1848,26 @@ class TrainingExecutorRunLocalTest(test.TestCase): executor.run_local() def test_errors_out_if_evaluate_returns_non_dict(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator) - train_spec = training.TrainSpec(input_fn=lambda: 1) - eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123) + est = estimator_lib.Estimator( + model_fn=self._model_fn, + config=run_config_lib.RunConfig(save_checkpoints_steps=2)) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) + train_spec = training.TrainSpec(input_fn=self._input_fn) + eval_spec = training.EvalSpec( + input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0) mock_est.evaluate.return_value = 123 executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_RESULT_TYPE_ERR): executor.run_local() def test_errors_out_if_evaluate_returns_dict_without_global_step(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator) - train_spec = training.TrainSpec(input_fn=lambda: 1) - eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123) + est = estimator_lib.Estimator( + model_fn=self._model_fn, + config=run_config_lib.RunConfig(save_checkpoints_steps=2)) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) + train_spec = training.TrainSpec(input_fn=self._input_fn) + eval_spec = training.EvalSpec( + input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0) mock_est.evaluate.return_value = {'loss': 123} executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) @@ -1887,19 +1876,21 @@ class TrainingExecutorRunLocalTest(test.TestCase): executor.run_local() def test_train_and_evaluate_return_metrics(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') - mock_est.latest_checkpoint.return_value = 'checkpoint_path/' + est = estimator_lib.Estimator(model_fn=self._model_fn) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) train_spec = training.TrainSpec( - input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) + input_fn=self._input_fn, max_steps=12, hooks=[_FakeHook()]) eval_spec = training.EvalSpec( - input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='local_eval') - mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps} + input_fn=lambda: self._input_fn(repeat=False), + steps=2, + hooks=[_FakeHook()], + name='local_eval') executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) # pylint: disable=assignment-from-no-return metrics, _ = executor.run_local() # pylint: enable=assignment-from-no-return - self.assertEqual(metrics['global_step'], 300) + self.assertEqual(metrics['global_step'], 12) class TrainAndEvaluateRunTest(test.TestCase): @@ -2096,7 +2087,7 @@ class TrainAndEvaluateIntegrationTest(test.TestCase): # max_steps should be larger than save_summary_steps max_steps = 10 - save_summary_steps = 2 + save_summary_steps = 9 data = np.linspace( 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32) @@ -2104,24 +2095,20 @@ class TrainAndEvaluateIntegrationTest(test.TestCase): y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1)) # learn y = x - train_input_fn = numpy_io.numpy_input_fn( - x={'x': x_data}, - y=y_data, - batch_size=batch_size, - num_epochs=None, - shuffle=True) - - eval_input_fn = numpy_io.numpy_input_fn( - x={'x': x_data}, - y=y_data, - batch_size=batch_size, - num_epochs=1, - shuffle=False) - - predict_input_fn = numpy_io.numpy_input_fn( - x={'x': x_data}, - batch_size=batch_size, - shuffle=False) + def train_input_fn(): + return dataset_ops.Dataset.from_tensor_slices(({ + 'x': x_data + }, y_data)).batch(batch_size).repeat().shuffle(1000) + + def eval_input_fn(): + return dataset_ops.Dataset.from_tensor_slices(({ + 'x': x_data + }, y_data)).batch(batch_size) + + def predict_input_fn(): + return dataset_ops.Dataset.from_tensor_slices({ + 'x': x_data + }).batch(batch_size) feature_columns = [ feature_column.numeric_column('x', shape=(input_dimension,))] @@ -2137,9 +2124,11 @@ class TrainAndEvaluateIntegrationTest(test.TestCase): max_steps=max_steps) eval_spec = training.EvalSpec( - name=eval_name, input_fn=eval_input_fn, steps=None, + name=eval_name, + input_fn=eval_input_fn, + steps=None, exporters=self._get_exporter(exporter_name, feature_columns), - throttle_secs=2) + throttle_secs=0) training.train_and_evaluate(est, train_spec, eval_spec) @@ -2148,15 +2137,12 @@ class TrainAndEvaluateIntegrationTest(test.TestCase): # Examine the training events. Use a range to check global step to avoid # flakyness due to global step race condition. - training_loss, training_global_step = self._extract_loss_and_global_step( - est.model_dir) + training_loss, _ = self._extract_loss_and_global_step(est.model_dir) self.assertIsNotNone(training_loss) - self.assertTrue( - max_steps - save_summary_steps < training_global_step <= max_steps) # Examine the eval events. The global step should be accurate. eval_loss, eval_global_step = self._extract_loss_and_global_step( - event_folder=os.path.join(est.model_dir, 'eval_' + eval_name)) + event_folder=est.eval_dir(eval_name)) self.assertIsNotNone(eval_loss) self.assertEqual(max_steps, eval_global_step) diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py index 924ca309ff0455d3bb06be61ce65bb0a61e84fb0..d4a75478d53f5b3dc8e66df98a78b51a6d25aab8 100644 --- a/tensorflow/python/estimator/util.py +++ b/tensorflow/python/estimator/util.py @@ -22,6 +22,7 @@ from __future__ import print_function import os import time +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import training @@ -129,3 +130,24 @@ class _DatasetInitializerHook(training.SessionRunHook): def after_create_session(self, session, coord): del coord session.run(self._initializer) + + +class StrategyInitFinalizeHook(training.SessionRunHook): + """Creates a SessionRunHook that initializes and shutsdown devices.""" + + def __init__(self, initialization_fn, finalize_fn): + self._initialization_fn = initialization_fn + self._finalize_fn = finalize_fn + + def begin(self): + self._init_ops = self._initialization_fn() + self._finalize_ops = self._finalize_fn() + + def after_create_session(self, session, coord): + logging.info('Initialize system') + session.run(self._init_ops, + options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)) + + def end(self, session): + logging.info('Finalize system.') + session.run(self._finalize_ops) diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index a58c5aabbe4b57b38ba7894900cbf390d3fe7669..40219e4b342de8e69f0b45f32a1f7b3eccfa3b80 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -452,13 +452,15 @@ def linear_model(features, ValueError: if an item in `feature_columns` is neither a `_DenseColumn` nor `_CategoricalColumn`. """ + with variable_scope.variable_scope(None, 'linear_model') as vs: + model_name = _strip_leading_slashes(vs.name) linear_model_layer = _LinearModel( feature_columns=feature_columns, units=units, sparse_combiner=sparse_combiner, weight_collections=weight_collections, trainable=trainable, - name='linear_model') + name=model_name) retval = linear_model_layer(features) # pylint: disable=not-callable if cols_to_vars is not None: cols_to_vars.update(linear_model_layer.cols_to_vars()) @@ -466,13 +468,25 @@ def linear_model(features, def _add_to_collections(var, weight_collections): - # TODO(rohanj): Explore adding a _get_variable_list method on `Variable` - # so that we don't have to do this check. - if isinstance(var, variables.PartitionedVariable): - for constituent_var in list(var): - ops.add_to_collections(weight_collections, constituent_var) - else: - ops.add_to_collections(weight_collections, var) + """Adds a var to the list of weight_collections provided. + + Handles the case for partitioned and non-partitioned variables. + + Args: + var: A variable or Partitioned Variable. + weight_collections: List of collections to add variable to. + """ + for weight_collection in weight_collections: + # The layer self.add_variable call already adds it to GLOBAL_VARIABLES. + if weight_collection == ops.GraphKeys.GLOBAL_VARIABLES: + continue + # TODO(rohanj): Explore adding a _get_variable_list method on `Variable` + # so that we don't have to do this check. + if isinstance(var, variables.PartitionedVariable): + for constituent_var in list(var): + ops.add_to_collection(weight_collection, constituent_var) + else: + ops.add_to_collection(weight_collection, var) class _FCLinearWrapper(base.Layer): @@ -583,6 +597,8 @@ class _LinearModel(training.Model): self._feature_columns = _normalize_feature_columns( feature_columns) self._weight_collections = list(weight_collections or []) + if ops.GraphKeys.GLOBAL_VARIABLES not in self._weight_collections: + self._weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES) if ops.GraphKeys.MODEL_VARIABLES not in self._weight_collections: self._weight_collections.append(ops.GraphKeys.MODEL_VARIABLES) @@ -971,7 +987,12 @@ def shared_embedding_columns( ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt` is specified. ValueError: if `initializer` is specified and is not callable. + RuntimeError: if eager execution is enabled. """ + if context.executing_eagerly(): + raise RuntimeError('shared_embedding_columns are not supported when eager ' + 'execution is enabled.') + if (dimension is None) or (dimension < 1): raise ValueError('Invalid dimension {}.'.format(dimension)) if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None): @@ -1016,16 +1037,6 @@ def shared_embedding_columns( shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns) shared_embedding_collection_name += '_shared_embedding' - # Create the state (_SharedEmbeddingColumnLayer) here. - embedding_shape = num_buckets, dimension - - shared_embedding_column_layer = _EmbeddingColumnLayer( - embedding_shape=embedding_shape, - initializer=initializer, - weight_collections=[], - trainable=trainable, - name=shared_embedding_collection_name) - result = [] for column in categorical_columns: result.append( @@ -1034,16 +1045,12 @@ def shared_embedding_columns( initializer=initializer, dimension=dimension, combiner=combiner, - var_scope_name=shared_embedding_collection_name, + shared_embedding_collection_name=shared_embedding_collection_name, ckpt_to_load_from=ckpt_to_load_from, tensor_name_in_ckpt=tensor_name_in_ckpt, max_norm=max_norm, trainable=trainable)) - for single_result in result: - single_result._set_layer(shared_embedding_column_layer) # pylint: disable=protected-access - single_result._set_all_columns(result) # pylint: disable=protected-access - return result @@ -1863,11 +1870,8 @@ class _EmbeddingColumnLayer(base.Layer): dtype=dtypes.float32, initializer=self._initializer, trainable=self.trainable) - # self.add_variable already appends to GLOBAL_VARIABLES collection. if self._weight_collections and not context.executing_eagerly(): - for weight_collection in self._weight_collections: - if weight_collection != ops.GraphKeys.GLOBAL_VARIABLES: - _add_to_collections(self._embedding_weight_var, [weight_collection]) + _add_to_collections(self._embedding_weight_var, self._weight_collections) self.built = True def call(self, _): @@ -2649,8 +2653,8 @@ class _SharedEmbeddingColumn( collections.namedtuple( '_SharedEmbeddingColumn', ('categorical_column', 'dimension', 'combiner', 'initializer', - 'var_scope_name', 'ckpt_to_load_from', 'tensor_name_in_ckpt', - 'max_norm', 'trainable'))): + 'shared_embedding_collection_name', 'ckpt_to_load_from', + 'tensor_name_in_ckpt', 'max_norm', 'trainable'))): """See `embedding_column`.""" @property @@ -2661,7 +2665,7 @@ class _SharedEmbeddingColumn( @property def _var_scope_name(self): - return self.var_scope_name + return self.shared_embedding_collection_name @property def _parse_example_spec(self): @@ -2670,22 +2674,6 @@ class _SharedEmbeddingColumn( def _transform_feature(self, inputs): return inputs.get(self.categorical_column) - def _set_layer(self, layer): - self._layer = layer - - def _set_all_columns(self, all_columns): - self._all_columns = all_columns - - def _reset_config(self): - config = self._layer.get_config() - config['embedding_shape'] = ( - self.categorical_column._num_buckets, # pylint: disable=protected-access - self.dimension) - config['initializer'] = self.initializer - self._layer = self._layer.__class__.from_config(config) - for column in self._all_columns: - column._set_layer(self._layer) # pylint: disable=protected-access - @property def _variable_shape(self): if not hasattr(self, '_shape'): @@ -2707,19 +2695,38 @@ class _SharedEmbeddingColumn( sparse_ids = sparse_tensors.id_tensor sparse_weights = sparse_tensors.weight_tensor - self._layer.set_weight_collections(weight_collections) - embedding_weights = self._layer( - None, scope=variable_scope.get_variable_scope()) - # If we're in graph mode and this is called with a different graph, - # then we should reset. - if not context.executing_eagerly() and ( - ops.get_default_graph() != - _get_graph_for_variable(embedding_weights)): - self._reset_config() - self._layer.set_weight_collections(weight_collections) - embedding_weights = self._layer( - None, scope=variable_scope.get_variable_scope()) - + embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access + shared_embedding_collection = ops.get_collection( + self.shared_embedding_collection_name) + if shared_embedding_collection: + if len(shared_embedding_collection) > 1: + raise ValueError( + 'Collection {} can only contain one variable. ' + 'Suggested fix A: Choose a unique name for this collection. ' + 'Suggested fix B: Do not add any variables to this collection. ' + 'The feature_column library already adds a variable under the ' + 'hood.'.format(shared_embedding_collection)) + embedding_weights = shared_embedding_collection[0] + if embedding_weights.get_shape() != embedding_shape: + raise ValueError( + 'Shared embedding collection {} contains variable {} of ' + 'unexpected shape {}. Expected shape is {}. ' + 'Suggested fix A: Choose a unique name for this collection. ' + 'Suggested fix B: Do not add any variables to this collection. ' + 'The feature_column library already adds a variable under the ' + 'hood.'.format(self.shared_embedding_collection_name, + embedding_weights.name, + embedding_weights.get_shape(), embedding_shape)) + else: + embedding_weights = variable_scope.get_variable( + name='embedding_weights', + shape=embedding_shape, + dtype=dtypes.float32, + initializer=self.initializer, + trainable=self.trainable and trainable, + collections=weight_collections) + ops.add_to_collection(self.shared_embedding_collection_name, + embedding_weights) if self.ckpt_to_load_from is not None: to_restore = embedding_weights if isinstance(to_restore, variables.PartitionedVariable): @@ -3579,8 +3586,3 @@ class _SequenceCategoricalColumn( weight_tensor, shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0)) return _CategoricalColumn.IdWeightPair(id_tensor, weight_tensor) - - -# TODO(xiejw): Remove the following alias once call sites are updated. -_clean_feature_columns = _normalize_feature_columns -_to_sparse_input = _to_sparse_input_and_drop_ignore_values diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index 627430d6bc5995cf054482ac3004098b8a2472ab..511205451cdee707d80993bd37eaad395625e773 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -1257,14 +1257,14 @@ class CrossedColumnTest(test.TestCase): }, (crossed,)) -def get_linear_model_bias(): - with variable_scope.variable_scope('linear_model', reuse=True): +def get_linear_model_bias(name='linear_model'): + with variable_scope.variable_scope(name, reuse=True): return variable_scope.get_variable('bias_weights') -def get_linear_model_column_var(column): +def get_linear_model_column_var(column, name='linear_model'): return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, - 'linear_model/' + column.name)[0] + name + '/' + column.name)[0] def get_keras_linear_model_predictions(features, @@ -1928,6 +1928,27 @@ class LinearModelTest(test.TestCase): with self.assertRaisesOpError('Feature .* cannot have rank 0'): sess.run(net, feed_dict={features['price']: np.array(1)}) + def test_multiple_linear_models(self): + price = fc.numeric_column('price') + with ops.Graph().as_default(): + features1 = {'price': [[1.], [5.]]} + features2 = {'price': [[2.], [10.]]} + predictions1 = fc.linear_model(features1, [price]) + predictions2 = fc.linear_model(features2, [price]) + bias1 = get_linear_model_bias(name='linear_model') + bias2 = get_linear_model_bias(name='linear_model_1') + price_var1 = get_linear_model_column_var(price, name='linear_model') + price_var2 = get_linear_model_column_var(price, name='linear_model_1') + with _initialized_session() as sess: + self.assertAllClose([0.], bias1.eval()) + sess.run(price_var1.assign([[10.]])) + sess.run(bias1.assign([5.])) + self.assertAllClose([[15.], [55.]], predictions1.eval()) + self.assertAllClose([0.], bias2.eval()) + sess.run(price_var2.assign([[10.]])) + sess.run(bias2.assign([5.])) + self.assertAllClose([[25.], [105.]], predictions2.eval()) + class _LinearModelTest(test.TestCase): @@ -2586,7 +2607,7 @@ class _LinearModelTest(test.TestCase): class InputLayerTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_retrieving_input(self): features = {'a': [0.]} input_layer = InputLayer(fc.numeric_column('a')) @@ -5329,9 +5350,9 @@ class SharedEmbeddingColumnTest(test.TestCase): self.assertIsNone(embedding_column_a.ckpt_to_load_from) self.assertIsNone(embedding_column_b.ckpt_to_load_from) self.assertEqual('aaa_bbb_shared_embedding', - embedding_column_a.var_scope_name) + embedding_column_a.shared_embedding_collection_name) self.assertEqual('aaa_bbb_shared_embedding', - embedding_column_b.var_scope_name) + embedding_column_b.shared_embedding_collection_name) self.assertIsNone(embedding_column_a.tensor_name_in_ckpt) self.assertIsNone(embedding_column_b.tensor_name_in_ckpt) self.assertIsNone(embedding_column_a.max_norm) @@ -5378,9 +5399,9 @@ class SharedEmbeddingColumnTest(test.TestCase): self.assertEqual('my_combiner', embedding_column_a.combiner) self.assertEqual('my_combiner', embedding_column_b.combiner) self.assertEqual('shared_embedding_collection_name', - embedding_column_a.var_scope_name) + embedding_column_a.shared_embedding_collection_name) self.assertEqual('shared_embedding_collection_name', - embedding_column_b.var_scope_name) + embedding_column_b.shared_embedding_collection_name) self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from) self.assertEqual('my_ckpt', embedding_column_b.ckpt_to_load_from) self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt) @@ -5431,7 +5452,7 @@ class SharedEmbeddingColumnTest(test.TestCase): self.assertEqual(embedding_dimension, embedding_column_a.dimension) self.assertEqual('my_combiner', embedding_column_a.combiner) self.assertEqual('shared_embedding_collection_name', - embedding_column_a.var_scope_name) + embedding_column_a.shared_embedding_collection_name) self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from) self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt) self.assertEqual(42., embedding_column_a.max_norm) diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py new file mode 100644 index 0000000000000000000000000000000000000000..9ccae761471e24ddb1d4d6acd89ebcc9650d1320 --- /dev/null +++ b/tensorflow/python/framework/error_interpolation.py @@ -0,0 +1,92 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Function for interpolating formatted errors from the TensorFlow runtime. + +Exposes the function `interpolate` to interpolate messages with tags of the form +^^type:name:format^^. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import itertools +import re +import string + +import six + +_NAME_REGEX = r"[A-Za-z0-9.][A-Za-z0-9_.\-/]*?" +_FORMAT_REGEX = r"[A-Za-z0-9_.\-/${}:]+" +_TAG_REGEX = r"\^\^({name}):({name}):({fmt})\^\^".format( + name=_NAME_REGEX, fmt=_FORMAT_REGEX) +_INTERPOLATION_REGEX = r"^(.*?)({tag})".format(tag=_TAG_REGEX) +_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX) + +_ParseTag = collections.namedtuple("_ParseTag", ["type", "name", "format"]) + + +def _parse_message(message): + """Parses the message. + + Splits the message into separators and tags. Tags are named tuples + representing the string ^^type:name:format^^ and they are separated by + separators. For example, in + "123^^node:Foo:${file}^^456^^node:Bar:${line}^^789", there are two tags and + three separators. The separators are the numeric characters. + + Args: + message: String to parse + + Returns: + (list of separator strings, list of _ParseTags). + + For example, if message is "123^^node:Foo:${file}^^456" then this function + returns (["123", "456"], [_ParseTag("node", "Foo", "${file}")]) + """ + seps = [] + tags = [] + pos = 0 + while pos < len(message): + match = re.match(_INTERPOLATION_PATTERN, message[pos:]) + if match: + seps.append(match.group(1)) + tags.append(_ParseTag(match.group(3), match.group(4), match.group(5))) + pos += match.end() + else: + break + seps.append(message[pos:]) + return seps, tags + + +# TODO(jtkeeling): Modify to actually interpolate format strings rather than +# echoing them. +def interpolate(error_message): + """Interpolates an error message. + + The error message can contain tags of the form ^^type:name:format^^ which will + be replaced. + + Args: + error_message: A string to interpolate. + + Returns: + The string with tags of the form ^^type:name:format^^ interpolated. + """ + seps, tags = _parse_message(error_message) + subs = [string.Template(tag.format).safe_substitute({}) for tag in tags] + return "".join( + itertools.chain(*six.moves.zip_longest(seps, subs, fillvalue=""))) diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ad448deb622cb6a3d24e502d7238d3f614d5af4d --- /dev/null +++ b/tensorflow/python/framework/error_interpolation_test.py @@ -0,0 +1,49 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.python.framework.errors.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import error_interpolation +from tensorflow.python.platform import test + + +class InterpolateTest(test.TestCase): + + def testNothingToDo(self): + normal_string = "This is just a normal string" + interpolated_string = error_interpolation.interpolate(normal_string) + self.assertEqual(interpolated_string, normal_string) + + def testOneTag(self): + one_tag_string = "^^node:Foo:${file}^^" + interpolated_string = error_interpolation.interpolate(one_tag_string) + self.assertEqual(interpolated_string, "${file}") + + def testTwoTagsNoSeps(self): + two_tags_no_seps = "^^node:Foo:${file}^^^^node:Bar:${line}^^" + interpolated_string = error_interpolation.interpolate(two_tags_no_seps) + self.assertEqual(interpolated_string, "${file}${line}") + + def testTwoTagsWithSeps(self): + two_tags_with_seps = "123^^node:Foo:${file}^^456^^node:Bar:${line}^^789" + interpolated_string = error_interpolation.interpolate(two_tags_with_seps) + self.assertEqual(interpolated_string, "123${file}456${line}789") + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 002a3d3be5dee5b64e0e227386be5697ae0598c6..6525607faea62a461ee38fa0393ac29b809bb9b6 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -23,6 +23,7 @@ from __future__ import print_function import collections import hashlib +import sys from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import function_pb2 @@ -33,6 +34,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_to_function_def from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import cond_v2_impl from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.util import compat @@ -40,6 +42,9 @@ from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect +# This is to avoid a circular dependency with cond_v2_impl. +cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-access + class Defun(object): """Decorator used to define TensorFlow functions. diff --git a/tensorflow/python/framework/function_def_to_graph.py b/tensorflow/python/framework/function_def_to_graph.py index 4fecc41343f26291dac8455f6c972a755b65ecfc..46c9c4c14adc7d4adeb11b45210cb296acb55086 100644 --- a/tensorflow/python/framework/function_def_to_graph.py +++ b/tensorflow/python/framework/function_def_to_graph.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys + from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.core.framework import versions_pb2 @@ -25,6 +27,10 @@ from tensorflow.python.framework import function from tensorflow.python.framework import importer from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import versions +from tensorflow.python.ops import cond_v2_impl + +# This is to avoid a circular dependency with cond_v2_impl. +cond_v2_impl._function_def_to_graph = sys.modules[__name__] # pylint: disable=protected-access def function_def_to_graph(fdef, input_shapes=None): diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index 72eb7e0eeb73fb1f8725ab2cbd4182e543c79b9f..699d2b70d176db7718a6e480f9f7b08a65ae6a8e 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -407,11 +407,11 @@ def import_graph_def(graph_def, _PopulateTFImportGraphDefOptions(options, prefix, input_map, return_elements) - # _ProcessNewOps mutates the new operations. _lock ensures a Session.run - # call cannot occur between creating the TF_Operations in the + # _ProcessNewOps mutates the new operations. _mutation_lock ensures a + # Session.run call cannot occur between creating the TF_Operations in the # TF_GraphImportGraphDefWithResults call and mutating the them in # _ProcessNewOps. - with graph._lock: # pylint: disable=protected-access + with graph._mutation_lock(): # pylint: disable=protected-access with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized: try: results = c_api.TF_GraphImportGraphDefWithResults( diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index b440cde3ad4268b9f53da77af0c570f7c0f51d65..cf0b1e36fb3f02c85873a0da81dc056d2fbd5f6a 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -55,6 +55,7 @@ from tensorflow.python.platform import app from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import decorator_utils +from tensorflow.python.util import lock_util from tensorflow.python.util import tf_contextlib from tensorflow.python.util.deprecation import deprecated_args from tensorflow.python.util.tf_export import tf_export @@ -63,7 +64,7 @@ from tensorflow.python.util.tf_export import tf_export # Temporary global switches determining if we should enable the work-in-progress # calls to the C API. These will be removed once all functionality is supported. _USE_C_API = True -_USE_C_SHAPES = os.getenv("TF_C_API_GRAPH_CONSTRUCTION_SHAPES", "0") is not "0" +_USE_C_SHAPES = os.getenv("TF_C_API_GRAPH_CONSTRUCTION_SHAPES", "1") != "0" def tensor_id(tensor): @@ -2599,6 +2600,10 @@ def _name_from_scope_name(name): return name[:-1] if (name and name[-1] == "/") else name +_MUTATION_LOCK_GROUP = 0 +_SESSION_RUN_LOCK_GROUP = 1 + + @tf_export("Graph") class Graph(object): """A TensorFlow computation, represented as a dataflow graph. @@ -2648,20 +2653,21 @@ class Graph(object): def __init__(self): """Creates a new, empty Graph.""" - # Protects core state that can be returned via public accessors, as well as - # synchronizes Session.run calls with methods that create and mutate ops - # (e.g. Graph.create_op()). This synchronization is necessary because it's - # illegal to modify an operation after it's been run. Thread-safety is - # provided on a best-effort basis to support buggy programs, and is not - # guaranteed by the public `tf.Graph` API. - # - # The lock must be reentrant because create_op can be called recursively due - # to control flow. Without a reentrant lock, many methods would also need a - # "locked" version or parameter (including generated code). + # Protects core state that can be returned via public accessors. + # Thread-safety is provided on a best-effort basis to support buggy + # programs, and is not guaranteed by the public `tf.Graph` API. # # NOTE(mrry): This does not protect the various stacks. A warning will # be reported if these are used from multiple threads self._lock = threading.RLock() + # The group lock synchronizes Session.run calls with methods that create + # and mutate ops (e.g. Graph.create_op()). This synchronization is + # necessary because it's illegal to modify an operation after it's been run. + # The group lock allows any number of threads to mutate ops at the same time + # but if any modification is going on, all Session.run calls have to wait. + # Similarly, if one or more Session.run calls are going on, all mutate ops + # have to wait until all Session.run calls have finished. + self._group_lock = lock_util.GroupLock(num_groups=2) self._nodes_by_id = dict() # GUARDED_BY(self._lock) self._next_id_counter = 0 # GUARDED_BY(self._lock) self._nodes_by_name = dict() # GUARDED_BY(self._lock) @@ -3192,9 +3198,9 @@ class Graph(object): input_ops = set([t.op for t in inputs]) control_inputs = self._control_dependencies_for_inputs(input_ops) - # _create_op_helper mutates the new Operation. _lock ensures a Session.run - # call cannot occur between creating and mutating the op. - with self._lock: + # _create_op_helper mutates the new Operation. `_mutation_lock` ensures a + # Session.run call cannot occur between creating and mutating the op. + with self._mutation_lock(): ret = Operation( node_def, self, @@ -4727,6 +4733,20 @@ class Graph(object): else: self._graph_control_dependencies_stack = control_dependencies + def _mutation_lock(self): + """Returns a lock to guard code that creates & mutates ops. + + See the comment for self._group_lock for more info. + """ + return self._group_lock.group(_MUTATION_LOCK_GROUP) + + def _session_run_lock(self): + """Returns a lock to guard code for Session.run. + + See the comment for self._group_lock for more info. + """ + return self._group_lock.group(_SESSION_RUN_LOCK_GROUP) + # TODO(agarwal): currently device directives in an outer eager scope will not # apply to inner graph mode code. Fix that. @@ -5155,7 +5175,8 @@ def init_scope(): @tf_export("enable_eager_execution") -def enable_eager_execution(config=None, device_policy=None, +def enable_eager_execution(config=None, + device_policy=None, execution_mode=None): """Enables eager execution for the lifetime of this program. @@ -5215,6 +5236,31 @@ def enable_eager_execution(config=None, device_policy=None, TensorFlow graph, or if options provided conflict with a previous call to this function. """ + return enable_eager_execution_internal( + config, device_policy, execution_mode, None) + + +def enable_eager_execution_internal(config=None, + device_policy=None, + execution_mode=None, + server_def=None): + """Enables eager execution for the lifetime of this program. + + Most of the doc string for enable_eager_execution is relevant here as well. + Args: + config: See enable_eager_execution doc string + device_policy: See enable_eager_execution doc string + execution_mode: See enable_eager_execution doc string + server_def: (Optional.) A tensorflow::ServerDef proto. + Enables execution on remote devices. GrpcServers need to be started by + creating an identical server_def to this, and setting the appropriate + task_indexes, so that the servers can communicate. It will then be + possible to execute operations on remote devices. + + Raises: + ValueError + + """ if config is not None and not isinstance(config, config_pb2.ConfigProto): raise TypeError( "config must be a tf.ConfigProto, but got %s" % type(config)) @@ -5242,7 +5288,8 @@ def enable_eager_execution(config=None, device_policy=None, context._context = context.Context( config=config, device_policy=device_policy, - execution_mode=execution_mode) + execution_mode=execution_mode, + server_def=server_def) elif ((config is not None and config is not context._context._config) or (device_policy is not None and device_policy is not context._context._device_policy) or diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 750df4d8e3926fbaee7a38978457e448c21d64c7..150100d771bb41d3693d39dc6fa19baa40da4c04 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -1690,7 +1690,7 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase): # e should be dominated by c. self.assertEqual(e.op.control_inputs, []) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEager(self): def future(): future.calls += 1 @@ -1875,7 +1875,7 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase): class OpScopeTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNames(self): with ops.name_scope("foo") as foo: self.assertEqual("foo/", foo) @@ -1906,7 +1906,7 @@ class OpScopeTest(test_util.TensorFlowTestCase): with ops.name_scope("a//b/c") as foo10: self.assertEqual("a//b/c/", foo10) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerDefaultScopeName(self): with ops.name_scope(None, "default") as scope: self.assertEqual(scope, "default/") diff --git a/tensorflow/python/framework/random_seed_test.py b/tensorflow/python/framework/random_seed_test.py index 194492268631abfa911bd45f13a302c09a2c8bda..6696bffc6c553f3fcf458f52cb9cd386e2711ff4 100644 --- a/tensorflow/python/framework/random_seed_test.py +++ b/tensorflow/python/framework/random_seed_test.py @@ -26,7 +26,7 @@ from tensorflow.python.platform import test class RandomSeedTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRandomSeed(self): test_cases = [ # Each test case is a tuple with input to get_seed: diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index 35fff80c61b98e7603d3b7b5df3cabdb59059a72..d6edc1364369e1b4d06093879571cdb4e9ffe409 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -941,7 +941,7 @@ class ConstantValueTest(test.TestCase): class ConstantValueAsShapeTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConstant(self): np_val = np.random.rand(3).astype(np.int32) tf_val = constant_op.constant(np_val) @@ -954,13 +954,13 @@ class ConstantValueAsShapeTest(test.TestCase): tensor_shape.TensorShape([]), tensor_util.constant_value_as_shape(tf_val)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testShape(self): tf_val = array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3])) c_val = tensor_util.constant_value_as_shape(tf_val) self.assertEqual(tensor_shape.TensorShape([1, 2, 3]), c_val) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMinusOneBecomesNone(self): tf_val = constant_op.constant([-1, 1, -1], shape=[3]) c_val = tensor_util.constant_value_as_shape(tf_val) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 5582b14249599f96453c0686fd89d56d4985531d..2bc2a189fa8e825613ca834e2c06ea916074d455 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -27,6 +27,7 @@ import random import re import tempfile import threading +import unittest import numpy as np import six @@ -61,13 +62,13 @@ from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import versions from tensorflow.python.ops import array_ops -from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib from tensorflow.python.util import compat from tensorflow.python.util import nest +from tensorflow.python.util import tf_inspect from tensorflow.python.util.protobuf import compare from tensorflow.python.util.tf_export import tf_export @@ -414,8 +415,28 @@ def assert_no_new_pyobjects_executing_eagerly(f): f(self, **kwargs) gc.collect() previous_count = len(gc.get_objects()) + collection_sizes_before = { + collection: len(ops.get_collection(collection)) + for collection in ops.get_default_graph().collections} for _ in range(3): f(self, **kwargs) + # Note that gc.get_objects misses anything that isn't subject to garbage + # collection (C types). Collections are a common source of leaks, so we + # test for collection sizes explicitly. + for collection_key in ops.get_default_graph().collections: + collection = ops.get_collection(collection_key) + size_before = collection_sizes_before.get(collection_key, 0) + if len(collection) > size_before: + raise AssertionError( + ("Collection %s increased in size from " + "%d to %d (current items %s).") + % (collection_key, size_before, len(collection), collection)) + # Make sure our collection checks don't show up as leaked memory by + # removing references to temporary variables. + del collection + del collection_key + del size_before + del collection_sizes_before gc.collect() # There should be no new Python objects hanging around. new_count = len(gc.get_objects()) @@ -552,14 +573,14 @@ def assert_no_garbage_created(f): def run_all_in_graph_and_eager_modes(cls): """Execute all test methods in the given class with and without eager.""" - base_decorator = run_in_graph_and_eager_modes() + base_decorator = run_in_graph_and_eager_modes for name, value in cls.__dict__.copy().items(): if callable(value) and name.startswith("test"): setattr(cls, name, base_decorator(value)) return cls -def run_in_graph_and_eager_modes(__unused__=None, +def run_in_graph_and_eager_modes(func=None, config=None, use_gpu=True, reset_test=True, @@ -577,7 +598,7 @@ def run_in_graph_and_eager_modes(__unused__=None, ```python class MyTests(tf.test.TestCase): - @run_in_graph_and_eager_modes() + @run_in_graph_and_eager_modes def test_foo(self): x = tf.constant([1, 2]) y = tf.constant([3, 4]) @@ -594,7 +615,9 @@ def run_in_graph_and_eager_modes(__unused__=None, Args: - __unused__: Prevents silently skipping tests. + func: function to be annotated. If `func` is None, this method returns a + decorator the can be applied to a function. If `func` is not None this + returns the decorator applied to `func`. config: An optional config_pb2.ConfigProto to use to configure the session when executing graphs. use_gpu: If True, attempt to run as many operations as possible on GPU. @@ -616,20 +639,19 @@ def run_in_graph_and_eager_modes(__unused__=None, eager execution enabled. """ - assert not __unused__, "Add () after run_in_graph_and_eager_modes." - def decorator(f): - def decorated(self, **kwargs): - with context.graph_mode(): - with self.test_session(use_gpu=use_gpu): - f(self, **kwargs) + if tf_inspect.isclass(f): + raise ValueError( + "`run_test_in_graph_and_eager_modes` only supports test methods. " + "Did you mean to use `run_all_tests_in_graph_and_eager_modes`?") - if reset_test: - # This decorator runs the wrapped test twice. - # Reset the test environment between runs. - self.tearDown() - self._tempdir = None - self.setUp() + def decorated(self, **kwargs): + try: + with context.graph_mode(): + with self.test_session(use_gpu=use_gpu, config=config): + f(self, **kwargs) + except unittest.case.SkipTest: + pass def run_eagerly(self, **kwargs): if not use_gpu: @@ -644,10 +666,20 @@ def run_in_graph_and_eager_modes(__unused__=None, assert_no_garbage_created(run_eagerly)) with context.eager_mode(): + if reset_test: + # This decorator runs the wrapped test twice. + # Reset the test environment between runs. + self.tearDown() + self._tempdir = None + self.setUp() + run_eagerly(self, **kwargs) return decorated + if func is not None: + return decorator(func) + return decorator @@ -830,14 +862,13 @@ class TensorFlowTestCase(googletest.TestCase): def _eval_tensor(self, tensor): if tensor is None: return None - elif isinstance(tensor, ops.EagerTensor): - return tensor.numpy() - elif isinstance(tensor, resource_variable_ops.ResourceVariable): - return tensor.read_value().numpy() elif callable(tensor): return self._eval_helper(tensor()) else: - raise ValueError("Unsupported type %s." % type(tensor)) + try: + return tensor.numpy() + except AttributeError as e: + six.raise_from(ValueError("Unsupported type %s." % type(tensor)), e) def _eval_helper(self, tensors): if tensors is None: diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index 0178908bcc9c0613353e3beea8e1eb11638f9531..122c14c8473f133f6a3bed1e6297394eaa1b845c 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -569,7 +569,7 @@ class TestUtilTest(test_util.TensorFlowTestCase): self.assertEqual(a_np_rand, b_np_rand) self.assertEqual(a_rand, b_rand) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_callable_evaluate(self): def model(): return resource_variable_ops.ResourceVariable( @@ -578,7 +578,7 @@ class TestUtilTest(test_util.TensorFlowTestCase): with context.eager_mode(): self.assertEqual(2, self.evaluate(model)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_nested_tensors_evaluate(self): expected = {"a": 1, "b": 2, "nested": {"d": 3, "e": 4}} nested = {"a": constant_op.constant(1), @@ -588,6 +588,27 @@ class TestUtilTest(test_util.TensorFlowTestCase): self.assertEqual(expected, self.evaluate(nested)) + def test_run_in_graph_and_eager_modes(self): + l = [] + def inc(self, with_brackets): + del self # self argument is required by run_in_graph_and_eager_modes. + mode = "eager" if context.executing_eagerly() else "graph" + with_brackets = "with_brackets" if with_brackets else "without_brackets" + l.append((with_brackets, mode)) + + f = test_util.run_in_graph_and_eager_modes(inc) + f(self, with_brackets=False) + f = test_util.run_in_graph_and_eager_modes()(inc) + f(self, with_brackets=True) + + self.assertEqual(len(l), 4) + self.assertEqual(set(l), { + ("with_brackets", "graph"), + ("with_brackets", "eager"), + ("without_brackets", "graph"), + ("without_brackets", "eager"), + }) + def test_get_node_def_from_graph(self): graph_def = graph_pb2.GraphDef() node_foo = graph_def.node.add() @@ -595,6 +616,55 @@ class TestUtilTest(test_util.TensorFlowTestCase): self.assertIs(test_util.get_node_def_from_graph("foo", graph_def), node_foo) self.assertIsNone(test_util.get_node_def_from_graph("bar", graph_def)) + def test_run_in_eager_and_graph_modes_test_class(self): + msg = "`run_test_in_graph_and_eager_modes` only supports test methods.*" + with self.assertRaisesRegexp(ValueError, msg): + @test_util.run_in_graph_and_eager_modes() + class Foo(object): + pass + del Foo # Make pylint unused happy. + + def test_run_in_eager_and_graph_modes_skip_graph_runs_eager(self): + modes = [] + def _test(self): + if not context.executing_eagerly(): + self.skipTest("Skipping in graph mode") + modes.append("eager" if context.executing_eagerly() else "graph") + test_util.run_in_graph_and_eager_modes(_test)(self) + self.assertEqual(modes, ["eager"]) + + def test_run_in_eager_and_graph_modes_skip_eager_runs_graph(self): + modes = [] + def _test(self): + if context.executing_eagerly(): + self.skipTest("Skipping in eager mode") + modes.append("eager" if context.executing_eagerly() else "graph") + test_util.run_in_graph_and_eager_modes(_test)(self) + self.assertEqual(modes, ["graph"]) + + def test_run_in_graph_and_eager_modes_setup_in_same_mode(self): + modes = [] + mode_name = lambda: "eager" if context.executing_eagerly() else "graph" + + class ExampleTest(test_util.TensorFlowTestCase): + + def runTest(self): + pass + + def setUp(self): + modes.append("setup_" + mode_name()) + + @test_util.run_in_graph_and_eager_modes + def testBody(self): + modes.append("run_" + mode_name()) + + e = ExampleTest() + e.setUp() + e.testBody() + + self.assertEqual(modes[0:2], ["setup_graph", "run_graph"]) + self.assertEqual(modes[2:], ["setup_eager", "run_eager"]) + class GarbageCollectionTest(test_util.TensorFlowTestCase): @@ -619,7 +689,7 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase): ReferenceCycleTest().test_has_no_cycle() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_no_leaked_tensor_decorator(self): class LeakedTensorTest(object): diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index af5d709f7e936e0438d5e03f60b44bc0017cb4b6..7d07c77c797668c858014cc31cf713050627d72f 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -158,6 +158,7 @@ def _get_config(layout_optimizer=True): layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF, # do not remove duplicated nodes arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF) + rewrite_options.min_graph_nodes = -1 graph_options = config_pb2.GraphOptions( rewrite_options=rewrite_options, build_cost_model=1) config = config_pb2.ConfigProto(graph_options=graph_options) @@ -1443,7 +1444,8 @@ class LayoutOptimizerTest(test.TestCase): def testGradient(self): meta_graph = _simple_metagraph() rewrite_options = rewriter_config_pb2.RewriterConfig( - layout_optimizer=rewriter_config_pb2.RewriterConfig.ON) + layout_optimizer=rewriter_config_pb2.RewriterConfig.ON, + min_graph_nodes=-1) optimized_graph = tf_optimizer.OptimizeGraph( rewrite_options, meta_graph, cluster=_get_cluster()) @@ -1457,7 +1459,8 @@ class LayoutOptimizerTest(test.TestCase): def testDepthwise(self): meta_graph = _simple_metagraph(depthwise=True) rewrite_options = rewriter_config_pb2.RewriterConfig( - layout_optimizer=rewriter_config_pb2.RewriterConfig.ON) + layout_optimizer=rewriter_config_pb2.RewriterConfig.ON, + min_graph_nodes=-1) optimized_graph = tf_optimizer.OptimizeGraph( rewrite_options, meta_graph, cluster=_get_cluster()) diff --git a/tensorflow/python/grappler/memory_optimizer_test.py b/tensorflow/python/grappler/memory_optimizer_test.py index 7ed4b128e495c484d294ece40541427f21856cf1..b658edff2dffac9856432c575b9af0d2f0b1986b 100644 --- a/tensorflow/python/grappler/memory_optimizer_test.py +++ b/tensorflow/python/grappler/memory_optimizer_test.py @@ -76,7 +76,8 @@ class MemoryOptimizerSwapTest(test.TestCase): disable_model_pruning=True, meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE, constant_folding=rewriter_config_pb2.RewriterConfig.OFF, - memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL) + memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL, + min_graph_nodes=-1) graph = tf_optimizer.OptimizeGraph(rewriter_config, mg) self.assertEqual(len(graph.node), graph_size + 2) @@ -133,6 +134,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase): dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF, layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF, arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, + min_graph_nodes=-1, memory_optimization=rewriter_config_pb2.RewriterConfig. RECOMPUTATION_HEURISTICS), original_metagraph) self.assertGreater( @@ -158,6 +160,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase): dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF, layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF, arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, + min_graph_nodes=-1, memory_optimization=rewriter_config_pb2.RewriterConfig. RECOMPUTATION_HEURISTICS, # Checks that name scope "gradients/" also match sub-scope. @@ -297,6 +300,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase): if 'Recomputed/' in node.name])) rewritten_graph_def = tf_optimizer.OptimizeGraph( rewriter_config_pb2.RewriterConfig( + min_graph_nodes=-1, memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL), metagraph) self.assertEqual( diff --git a/tensorflow/python/grappler/tf_optimizer_test.py b/tensorflow/python/grappler/tf_optimizer_test.py index 1c0f072dd32d38f048cfa48d38b45264951d095e..5a9afe725753749ea42d53382731ab14a3cf24f5 100644 --- a/tensorflow/python/grappler/tf_optimizer_test.py +++ b/tensorflow/python/grappler/tf_optimizer_test.py @@ -47,6 +47,7 @@ class PyWrapOptimizeGraphTest(test.TestCase): rewriter_config = rewriter_config_pb2.RewriterConfig() rewriter_config.optimizers.append('constfold') + rewriter_config.min_graph_nodes = -1 graph = tf_optimizer.OptimizeGraph(rewriter_config, mg) @@ -68,6 +69,7 @@ class PyWrapOptimizeGraphTest(test.TestCase): # Optimize the graph. mg = meta_graph.create_meta_graph_def(graph=g) rewriter_config = rewriter_config_pb2.RewriterConfig() + rewriter_config.min_graph_nodes = -1 optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, mg) # Check that the nodes referenced in various collections have been preserved @@ -109,6 +111,7 @@ class PyWrapOptimizeGraphTest(test.TestCase): # Optimize the graph. mg = meta_graph.create_meta_graph_def(graph=g) rewriter_config = rewriter_config_pb2.RewriterConfig() + rewriter_config.min_graph_nodes = -1 optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, mg) mg.graph_def.CopyFrom(optimized_graph) diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index fe40c9fbed7c041ad6b6dc8cdb1c50b80f57a48f..8b6b28bc776fa500a93d0a3fb3bf91081ba86967 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -39,6 +39,7 @@ py_library( "datasets/imdb.py", "datasets/mnist.py", "datasets/reuters.py", + "estimator/__init__.py", "preprocessing/__init__.py", "preprocessing/image.py", "preprocessing/sequence.py", @@ -135,7 +136,7 @@ py_library( deps = [ ":backend", "//tensorflow/python/data", - "//tensorflow/python/training/checkpointable:data_structures_base", + "//tensorflow/python/training/checkpointable:data_structures", "@six_archive//:six", ], ) @@ -549,7 +550,7 @@ py_test( py_test( name = "gru_test", - size = "medium", + size = "large", srcs = ["layers/gru_test.py"], srcs_version = "PY2AND3", tags = ["notsan"], # http://b/62136390 @@ -858,7 +859,7 @@ py_test( py_test( name = "backend_test", - size = "small", + size = "medium", srcs = ["backend_test.py"], srcs_version = "PY2AND3", deps = [ @@ -866,6 +867,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:util", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/python/keras/__init__.py b/tensorflow/python/keras/__init__.py index 3493069a5bf53ffbfe6447f2c1b3df7ac64cbf3a..198c66d9e184c82423e529540b92ad447b947cf8 100644 --- a/tensorflow/python/keras/__init__.py +++ b/tensorflow/python/keras/__init__.py @@ -27,6 +27,7 @@ from tensorflow.python.keras import backend from tensorflow.python.keras import callbacks from tensorflow.python.keras import constraints from tensorflow.python.keras import datasets +from tensorflow.python.keras import estimator from tensorflow.python.keras import initializers from tensorflow.python.keras import layers from tensorflow.python.keras import losses diff --git a/tensorflow/python/keras/applications/densenet.py b/tensorflow/python/keras/applications/densenet.py index f81f10719a31e2e79589d3b389049353c992091c..8df6d086111c4b179d2f0c7b5c1130a6cd95aaab 100644 --- a/tensorflow/python/keras/applications/densenet.py +++ b/tensorflow/python/keras/applications/densenet.py @@ -31,7 +31,6 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras.engine.network import get_source_inputs from tensorflow.python.keras.layers import Activation from tensorflow.python.keras.layers import AveragePooling2D from tensorflow.python.keras.layers import BatchNormalization @@ -44,6 +43,7 @@ from tensorflow.python.keras.layers import Input from tensorflow.python.keras.layers import MaxPooling2D from tensorflow.python.keras.layers import ZeroPadding2D from tensorflow.python.keras.models import Model +from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.util.tf_export import tf_export @@ -238,7 +238,7 @@ def DenseNet(blocks, # Ensure that the model takes into account # any potential predecessors of `input_tensor`. if input_tensor is not None: - inputs = get_source_inputs(input_tensor) + inputs = layer_utils.get_source_inputs(input_tensor) else: inputs = img_input diff --git a/tensorflow/python/keras/applications/inception_resnet_v2.py b/tensorflow/python/keras/applications/inception_resnet_v2.py index fe1d0f2d4fb47f7ebab38f94afc8ace2f7b73cbc..14e3b6aa60dbfa7e62e04849d35633eed162a416 100644 --- a/tensorflow/python/keras/applications/inception_resnet_v2.py +++ b/tensorflow/python/keras/applications/inception_resnet_v2.py @@ -31,7 +31,6 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras.engine.network import get_source_inputs from tensorflow.python.keras.layers import Activation from tensorflow.python.keras.layers import AveragePooling2D from tensorflow.python.keras.layers import BatchNormalization @@ -44,6 +43,7 @@ from tensorflow.python.keras.layers import Input from tensorflow.python.keras.layers import Lambda from tensorflow.python.keras.layers import MaxPooling2D from tensorflow.python.keras.models import Model +from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -354,7 +354,7 @@ def InceptionResNetV2(include_top=True, # Ensure that the model takes into account # any potential predecessors of `input_tensor` if input_tensor is not None: - inputs = get_source_inputs(input_tensor) + inputs = layer_utils.get_source_inputs(input_tensor) else: inputs = img_input diff --git a/tensorflow/python/keras/applications/inception_v3.py b/tensorflow/python/keras/applications/inception_v3.py index 857ad49dae9ef234fe7d8251601ee122de39c947..b5e28c781f71e67b8d835b50070b49add2d7930a 100644 --- a/tensorflow/python/keras/applications/inception_v3.py +++ b/tensorflow/python/keras/applications/inception_v3.py @@ -37,7 +37,6 @@ from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras.engine.network import get_source_inputs from tensorflow.python.keras.layers import Activation from tensorflow.python.keras.layers import AveragePooling2D from tensorflow.python.keras.layers import BatchNormalization @@ -48,6 +47,7 @@ from tensorflow.python.keras.layers import GlobalMaxPooling2D from tensorflow.python.keras.layers import Input from tensorflow.python.keras.layers import MaxPooling2D from tensorflow.python.keras.models import Model +from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -375,7 +375,7 @@ def InceptionV3(include_top=True, # Ensure that the model takes into account # any potential predecessors of `input_tensor`. if input_tensor is not None: - inputs = get_source_inputs(input_tensor) + inputs = layer_utils.get_source_inputs(input_tensor) else: inputs = img_input # Create model. diff --git a/tensorflow/python/keras/applications/mobilenet.py b/tensorflow/python/keras/applications/mobilenet.py index 9d845be0d5b1ab06dd8a41bc04f75ae7b5f00789..e56c695a288026d12de6bc0bdb65706c71eefe14 100644 --- a/tensorflow/python/keras/applications/mobilenet.py +++ b/tensorflow/python/keras/applications/mobilenet.py @@ -78,8 +78,7 @@ from tensorflow.python.keras import regularizers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine.network import get_source_inputs +from tensorflow.python.keras.engine.base_layer import InputSpec from tensorflow.python.keras.layers import Activation from tensorflow.python.keras.layers import BatchNormalization from tensorflow.python.keras.layers import Conv2D @@ -92,6 +91,7 @@ from tensorflow.python.keras.layers import Reshape from tensorflow.python.keras.layers import ZeroPadding2D from tensorflow.python.keras.models import Model from tensorflow.python.keras.utils import conv_utils +from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -317,7 +317,7 @@ def MobileNet(input_shape=None, # Ensure that the model takes into account # any potential predecessors of `input_tensor`. if input_tensor is not None: - inputs = get_source_inputs(input_tensor) + inputs = layer_utils.get_source_inputs(input_tensor) else: inputs = img_input diff --git a/tensorflow/python/keras/applications/nasnet.py b/tensorflow/python/keras/applications/nasnet.py index b521bc673139403dcdecbba8e35b5bafec2d42bf..ff79b3a057b8fd6ab3b0edf652a5bede0e2d7b87 100644 --- a/tensorflow/python/keras/applications/nasnet.py +++ b/tensorflow/python/keras/applications/nasnet.py @@ -49,7 +49,6 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras.applications.imagenet_utils import decode_predictions from tensorflow.python.keras.applications.inception_v3 import preprocess_input -from tensorflow.python.keras.engine.network import get_source_inputs from tensorflow.python.keras.layers import Activation from tensorflow.python.keras.layers import add from tensorflow.python.keras.layers import AveragePooling2D @@ -65,6 +64,7 @@ from tensorflow.python.keras.layers import MaxPooling2D from tensorflow.python.keras.layers import SeparableConv2D from tensorflow.python.keras.layers import ZeroPadding2D from tensorflow.python.keras.models import Model +from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -290,7 +290,7 @@ def NASNet(input_shape=None, # Ensure that the model takes into account # any potential predecessors of `input_tensor`. if input_tensor is not None: - inputs = get_source_inputs(input_tensor) + inputs = layer_utils.get_source_inputs(input_tensor) else: inputs = img_input diff --git a/tensorflow/python/keras/applications/resnet50.py b/tensorflow/python/keras/applications/resnet50.py index 508550f445e39dcf2a249bc91aaee289abfe3d1f..6afc08681214c5dbb0577623d30e27e9988c6a57 100644 --- a/tensorflow/python/keras/applications/resnet50.py +++ b/tensorflow/python/keras/applications/resnet50.py @@ -34,7 +34,6 @@ from tensorflow.python.keras import layers from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras.applications.imagenet_utils import decode_predictions from tensorflow.python.keras.applications.imagenet_utils import preprocess_input -from tensorflow.python.keras.engine.network import get_source_inputs from tensorflow.python.keras.layers import Activation from tensorflow.python.keras.layers import AveragePooling2D from tensorflow.python.keras.layers import BatchNormalization @@ -277,7 +276,7 @@ def ResNet50(include_top=True, # Ensure that the model takes into account # any potential predecessors of `input_tensor`. if input_tensor is not None: - inputs = get_source_inputs(input_tensor) + inputs = layer_utils.get_source_inputs(input_tensor) else: inputs = img_input # Create model. diff --git a/tensorflow/python/keras/applications/vgg16.py b/tensorflow/python/keras/applications/vgg16.py index 659a6533e6772402663aee891ed90df792b12f09..cef0230da96ed4b9c992e57839ebb2071383e3b1 100644 --- a/tensorflow/python/keras/applications/vgg16.py +++ b/tensorflow/python/keras/applications/vgg16.py @@ -32,7 +32,6 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras.applications.imagenet_utils import decode_predictions from tensorflow.python.keras.applications.imagenet_utils import preprocess_input -from tensorflow.python.keras.engine.network import get_source_inputs from tensorflow.python.keras.layers import Conv2D from tensorflow.python.keras.layers import Dense from tensorflow.python.keras.layers import Flatten @@ -202,7 +201,7 @@ def VGG16(include_top=True, # Ensure that the model takes into account # any potential predecessors of `input_tensor`. if input_tensor is not None: - inputs = get_source_inputs(input_tensor) + inputs = layer_utils.get_source_inputs(input_tensor) else: inputs = img_input # Create model. diff --git a/tensorflow/python/keras/applications/vgg19.py b/tensorflow/python/keras/applications/vgg19.py index 5e27ab8fb1fb99c65566cc4519798e3b8e0e1b0b..c4031f551003eda076380d1ae5208ee0876f5750 100644 --- a/tensorflow/python/keras/applications/vgg19.py +++ b/tensorflow/python/keras/applications/vgg19.py @@ -32,7 +32,6 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras.applications.imagenet_utils import decode_predictions from tensorflow.python.keras.applications.imagenet_utils import preprocess_input -from tensorflow.python.keras.engine.network import get_source_inputs from tensorflow.python.keras.layers import Conv2D from tensorflow.python.keras.layers import Dense from tensorflow.python.keras.layers import Flatten @@ -211,7 +210,7 @@ def VGG19(include_top=True, # Ensure that the model takes into account # any potential predecessors of `input_tensor`. if input_tensor is not None: - inputs = get_source_inputs(input_tensor) + inputs = layer_utils.get_source_inputs(input_tensor) else: inputs = img_input # Create model. diff --git a/tensorflow/python/keras/applications/xception.py b/tensorflow/python/keras/applications/xception.py index e1be8a3c46e6eafa43405f1472a2f0292b73aa0c..01397cfac2563273ba1215003df1afab293b6b20 100644 --- a/tensorflow/python/keras/applications/xception.py +++ b/tensorflow/python/keras/applications/xception.py @@ -44,7 +44,6 @@ from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras.engine.network import get_source_inputs from tensorflow.python.keras.layers import Activation from tensorflow.python.keras.layers import BatchNormalization from tensorflow.python.keras.layers import Conv2D @@ -55,6 +54,7 @@ from tensorflow.python.keras.layers import Input from tensorflow.python.keras.layers import MaxPooling2D from tensorflow.python.keras.layers import SeparableConv2D from tensorflow.python.keras.models import Model +from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils.data_utils import get_file from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -302,7 +302,7 @@ def Xception(include_top=True, # Ensure that the model takes into account # any potential predecessors of `input_tensor`. if input_tensor is not None: - inputs = get_source_inputs(input_tensor) + inputs = layer_utils.get_source_inputs(input_tensor) else: inputs = img_input # Create model. diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 84821918bfe6160d7ee1b4556e00f533a07f5ebd..11f99c030f309dbd6393c37a03db6d8b804c4dc0 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -22,6 +22,7 @@ from __future__ import division from __future__ import print_function import collections +import itertools import json import os import weakref @@ -2880,7 +2881,10 @@ class Function(object): feed_arrays.append(tensor) # We need to do array conversion and type casting at this level, since # `callable_fn` only supports exact matches. - array_vals.append(np.asarray(value, dtype=tensor.dtype.base_dtype.name)) + tensor_type = dtypes_module.as_dtype(tensor.dtype) + array_vals.append(np.asarray(value, + dtype=tensor_type.as_numpy_dtype)) + if self.feed_dict: for key in sorted(self.feed_dict.keys()): array_vals.append( @@ -3157,10 +3161,16 @@ def rnn(step_function, array_ops.stack( [1, array_ops.shape(output)[1]])) output = array_ops.where(tiled_mask_t, output, states[0]) - new_states = [ - array_ops.where(tiled_mask_t, new_states[i], states[i]) - for i in range(len(states)) - ] + + masked_states = [] + for i in range(len(states)): + states_dim = array_ops.shape(new_states[i])[1] + stacked_states_dim = array_ops.stack([1, states_dim]) + tiled_mask = array_ops.tile(mask_t, stacked_states_dim) + masked_state = array_ops.where(tiled_mask, new_states[i], states[i]) + masked_states.append(masked_state) + new_states = masked_states + output_ta_t = output_ta_t.write(time, output) return (time + 1, output_ta_t) + tuple(new_states) else: @@ -4242,58 +4252,115 @@ def pool3d(x, return x -def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None): - """Apply 1D conv with un-shared weights. - - Arguments: - inputs: 3D tensor with shape: - (batch_size, steps, input_dim) - if data_format is "channels_last" or - (batch_size, input_dim, steps) - if data_format is "channels_first". - kernel: the unshared weight for convolution, - with shape (output_length, feature_dim, filters) - kernel_size: a tuple of a single integer, - specifying the length of the 1D convolution window - strides: a tuple of a single integer, - specifying the stride length of the convolution - data_format: the data format, channels_first or channels_last - - Returns: - the tensor after 1d conv with un-shared weights, with shape (batch_size, - output_length, filters) +def local_conv(inputs, + kernel, + kernel_size, + strides, + output_shape, + data_format=None): + """Apply N-D convolution with un-shared weights. + + Arguments: + inputs: (N+2)-D tensor with shape + (batch_size, channels_in, d_in1, ..., d_inN) + if data_format='channels_first', or + (batch_size, d_in1, ..., d_inN, channels_in) + if data_format='channels_last'. + kernel: the unshared weight for N-D convolution, + with shape (output_items, feature_dim, channels_out), where + feature_dim = np.prod(kernel_size) * channels_in, + output_items = np.prod(output_shape). + kernel_size: a tuple of N integers, specifying the + spatial dimensions of the N-D convolution window. + strides: a tuple of N integers, specifying the strides + of the convolution along the spatial dimensions. + output_shape: a tuple of (d_out1, ..., d_outN) specifying the spatial + dimensionality of the output. + data_format: string, "channels_first" or "channels_last". + + Returns: + An (N+2)-D tensor with shape: + (batch_size, channels_out) + output_shape + if data_format='channels_first', or: + (batch_size,) + output_shape + (channels_out,) + if data_format='channels_last'. Raises: - ValueError: if `data_format` is neither `channels_last` or - `channels_first`. + ValueError: if `data_format` is neither + `channels_last` nor `channels_first`. """ if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: raise ValueError('Unknown data_format: ' + str(data_format)) - stride = strides[0] kernel_shape = int_shape(kernel) - output_length = kernel_shape[0] feature_dim = kernel_shape[1] + channels_out = kernel_shape[-1] + ndims = len(output_shape) + spatial_dimensions = list(range(ndims)) xs = [] - for i in range(output_length): - slice_length = slice(i * stride, i * stride + kernel_size[0]) + output_axes_ticks = [range(axis_max) for axis_max in output_shape] + for position in itertools.product(*output_axes_ticks): + slices = [slice(None)] + if data_format == 'channels_first': - xs.append(reshape(inputs[:, :, slice_length], (1, -1, feature_dim))) - else: - xs.append(reshape(inputs[:, slice_length, :], (1, -1, feature_dim))) + slices.append(slice(None)) + + slices.extend([slice(position[d] * strides[d], + position[d] * strides[d] + kernel_size[d]) + for d in spatial_dimensions]) + + if data_format == 'channels_last': + slices.append(slice(None)) + + xs.append(reshape(inputs[slices], (1, -1, feature_dim))) x_aggregate = concatenate(xs, axis=0) - # Shape: `(output_length, batch_size, filters)`. output = batch_dot(x_aggregate, kernel) + output = reshape(output, output_shape + (-1, channels_out)) if data_format == 'channels_first': - output = permute_dimensions(output, (1, 2, 0)) + permutation = [ndims, ndims + 1] + spatial_dimensions else: - output = permute_dimensions(output, (1, 0, 2)) - return output + permutation = [ndims] + spatial_dimensions + [ndims + 1] + + return permute_dimensions(output, permutation) + + +def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None): + """Apply 1D conv with un-shared weights. + + Arguments: + inputs: 3D tensor with shape: + (batch_size, steps, input_dim) + if data_format is "channels_last" or + (batch_size, input_dim, steps) + if data_format is "channels_first". + kernel: the unshared weight for convolution, + with shape (output_length, feature_dim, filters). + kernel_size: a tuple of a single integer, + specifying the length of the 1D convolution window. + strides: a tuple of a single integer, + specifying the stride length of the convolution. + data_format: the data format, channels_first or channels_last. + + Returns: + A 3d tensor with shape: + (batch_size, output_length, filters) + if data_format='channels_first' + or 3D tensor with shape: + (batch_size, filters, output_length) + if data_format='channels_last'. + """ + output_shape = (kernel.shape[0],) + return local_conv(inputs, + kernel, + kernel_size, + strides, + output_shape, + data_format) def local_conv2d(inputs, @@ -4306,64 +4373,34 @@ def local_conv2d(inputs, Arguments: inputs: 4D tensor with shape: - (batch_size, filters, new_rows, new_cols) - if data_format='channels_first' - or 4D tensor with shape: - (batch_size, new_rows, new_cols, filters) - if data_format='channels_last'. + (batch_size, filters, new_rows, new_cols) + if data_format='channels_first' + or 4D tensor with shape: + (batch_size, new_rows, new_cols, filters) + if data_format='channels_last'. kernel: the unshared weight for convolution, - with shape (output_items, feature_dim, filters) + with shape (output_items, feature_dim, filters). kernel_size: a tuple of 2 integers, specifying the - width and height of the 2D convolution window. + width and height of the 2D convolution window. strides: a tuple of 2 integers, specifying the strides - of the convolution along the width and height. - output_shape: a tuple with (output_row, output_col) - data_format: the data format, channels_first or channels_last + of the convolution along the width and height. + output_shape: a tuple with (output_row, output_col). + data_format: the data format, channels_first or channels_last. Returns: - A 4d tensor with shape: + A 4D tensor with shape: (batch_size, filters, new_rows, new_cols) if data_format='channels_first' or 4D tensor with shape: (batch_size, new_rows, new_cols, filters) if data_format='channels_last'. - - Raises: - ValueError: if `data_format` is neither - `channels_last` or `channels_first`. """ - if data_format is None: - data_format = image_data_format() - if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format: ' + str(data_format)) - - stride_row, stride_col = strides - output_row, output_col = output_shape - kernel_shape = int_shape(kernel) - feature_dim = kernel_shape[1] - filters = kernel_shape[2] - - xs = [] - for i in range(output_row): - for j in range(output_col): - slice_row = slice(i * stride_row, i * stride_row + kernel_size[0]) - slice_col = slice(j * stride_col, j * stride_col + kernel_size[1]) - if data_format == 'channels_first': - xs.append( - reshape(inputs[:, :, slice_row, slice_col], (1, -1, feature_dim))) - else: - xs.append( - reshape(inputs[:, slice_row, slice_col, :], (1, -1, feature_dim))) - - x_aggregate = concatenate(xs, axis=0) - output = batch_dot(x_aggregate, kernel) - output = reshape(output, (output_row, output_col, -1, filters)) - - if data_format == 'channels_first': - output = permute_dimensions(output, (2, 3, 0, 1)) - else: - output = permute_dimensions(output, (2, 0, 1, 3)) - return output + return local_conv(inputs, + kernel, + kernel_size, + strides, + output_shape, + data_format) @tf_export('keras.backend.bias_add') diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py index 53e30e0e4aeda1847c4f1867cb37b87841ad1ee7..0ddffa61a490d20bb1043346eeb41e68ca470125 100644 --- a/tensorflow/python/keras/backend_test.py +++ b/tensorflow/python/keras/backend_test.py @@ -17,10 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np import scipy.sparse from tensorflow.python import keras +from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -661,7 +663,7 @@ class BackendShapeOpsTest(test.TestCase): np_kwargs={'data_format': 'channels_first'}) -class BackendNNOpsTest(test.TestCase): +class BackendNNOpsTest(test.TestCase, parameterized.TestCase): def test_bias_add(self): with self.test_session(): @@ -810,52 +812,117 @@ class BackendNNOpsTest(test.TestCase): padding='same', data_format='channels_last') self.assertEqual(y.get_shape().as_list(), [10, 5, 5]) - def test_local_conv1d_channels_dim(self): - input_length = 5 - input_dim = 3 + def test_local_conv_channels_dim(self): + filters = 3 batch_size = 2 - inputs = np.random.normal(0, 1, (batch_size, input_dim, input_length)) - inputs_cf = keras.backend.variable(inputs) + for input_shape in [(3, 5), (2, 3, 5), (2, 5, 3, 4)]: + channels_in = input_shape[0] + input_spatial_shape = input_shape[1:] + dim = len(input_spatial_shape) - filters = 4 - for kernel_size in [(1,), (2,), (3,)]: - for strides in [(1,), (2,), (3,)]: - output_length = (input_length - kernel_size[0] - + strides[0]) // strides[0] + inputs = np.random.normal(0, 1, (batch_size,) + input_shape) + inputs_cf = keras.backend.variable(inputs) - kernel_shape = (output_length, kernel_size[0] * input_dim, filters) - kernel = np.random.normal(0, 1, (output_length, - input_dim, - kernel_size[0], - filters)) - kernel_cf = np.reshape(kernel, kernel_shape) - kernel_cf = keras.backend.variable(kernel_cf) + for kernel_size in [1, 2]: + for stride in [1, 2]: + kernel_sizes = (kernel_size,) * dim + strides = (stride,) * dim - conv_cf = keras.backend.local_conv1d(inputs_cf, + output_shape = tuple([(i - kernel_size + stride) // stride + for i in input_spatial_shape]) + + kernel_shape = (np.prod(output_shape), + np.prod(kernel_sizes) * channels_in, + filters) + + kernel = np.random.normal( + 0, + 1, + output_shape + (channels_in, np.prod(kernel_sizes), filters) + ) + + kernel_cf = np.reshape(kernel, kernel_shape) + kernel_cf = keras.backend.variable(kernel_cf) + + conv_cf = keras.backend.local_conv(inputs_cf, kernel_cf, - kernel_size, + kernel_sizes, strides, + output_shape, 'channels_first') - inputs_cl = np.transpose(inputs, (0, 2, 1)) - inputs_cl = keras.backend.variable(inputs_cl) + inputs_cl = np.transpose(inputs, [0, 2] + list(range(3, dim + 2)) + + [1]) + inputs_cl = keras.backend.variable(inputs_cl) - kernel_cl = np.reshape(np.transpose(kernel, (0, 2, 1, 3)), - kernel_shape) - kernel_cl = keras.backend.variable(kernel_cl) + kernel_cl = np.reshape( + np.transpose(kernel, list(range(dim)) + [dim + 1, dim, dim + 2]), + kernel_shape + ) + kernel_cl = keras.backend.variable(kernel_cl) - conv_cl = keras.backend.local_conv1d(inputs_cl, + conv_cl = keras.backend.local_conv(inputs_cl, kernel_cl, - kernel_size, + kernel_sizes, strides, + output_shape, 'channels_last') - with self.test_session(): - conv_cf = keras.backend.eval(conv_cf) - conv_cl = keras.backend.eval(conv_cl) + with self.test_session(): + conv_cf = keras.backend.eval(conv_cf) + conv_cl = keras.backend.eval(conv_cl) + + self.assertAllCloseAccordingToType( + conv_cf, + np.transpose(conv_cl, + [0, dim + 1] + list(range(1, dim + 1))), + atol=1e-5 + ) + + @parameterized.named_parameters( + ('local_conv1d', (5, 6), (3,), (1,), (3,)), + ('local_conv2d', (4, 5, 6), (3, 3), (1, 1), (2, 3))) + def test_local_conv_1d_and_2d(self, + input_shape, + kernel_sizes, + strides, + output_shape): + filters = 3 + batch_size = 2 + + inputs = np.random.normal(0, 1, (batch_size,) + input_shape) + inputs = keras.backend.variable(inputs) + + kernel = np.random.normal(0, 1, (np.prod(output_shape), + np.prod(kernel_sizes) * input_shape[-1], + filters)) + kernel = keras.backend.variable(kernel) + + local_conv = keras.backend.local_conv(inputs, + kernel, + kernel_sizes, + strides, + output_shape, + 'channels_last') + if len(output_shape) == 1: + local_conv_dim = keras.backend.local_conv1d(inputs, + kernel, + kernel_sizes, + strides, + 'channels_last') + else: + local_conv_dim = keras.backend.local_conv2d(inputs, + kernel, + kernel_sizes, + strides, + output_shape, + 'channels_last') - self.assertAllCloseAccordingToType(conv_cf, - np.transpose(conv_cl, (0, 2, 1))) + with self.test_session(): + local_conv = keras.backend.eval(local_conv) + local_conv_dim = keras.backend.eval(local_conv_dim) + + self.assertAllCloseAccordingToType(local_conv, local_conv_dim) def test_conv2d(self): val = np.random.random((10, 4, 10, 10)) @@ -1010,7 +1077,7 @@ class BackendNNOpsTest(test.TestCase): {'go_backwards': False, 'mask': mask, 'unroll': True}, ] with self.test_session(): - for (i, kwargs) in enumerate(kwargs_list): + for i, kwargs in enumerate(kwargs_list): last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs, initial_states, **kwargs) @@ -1057,6 +1124,115 @@ class BackendNNOpsTest(test.TestCase): for b_s, b_u_s in zip(state_list[2], state_list[3]): self.assertAllClose(b_s, b_u_s, atol=1e-04) + def test_rnn_additional_states(self): + # implement a simple RNN + num_samples = 4 + input_dim = 5 + output_dim = 3 + timesteps = 6 + + input_val = np.random.random( + (num_samples, timesteps, input_dim)).astype(np.float32) + init_state_val = np.random.random( + (num_samples, output_dim)).astype(np.float32) + w_i_val = np.random.random((input_dim, output_dim)).astype(np.float32) + w_o_val = np.random.random((output_dim, output_dim)).astype(np.float32) + np_mask = np.random.randint(2, size=(num_samples, timesteps)) + + def rnn_step_fn(): + w_i = keras.backend.variable(w_i_val) + w_o = keras.backend.variable(w_o_val) + + def step_function(x, states): + assert len(states) == 2 + prev_output = states[0] + output = keras.backend.dot(x, w_i) + keras.backend.dot(prev_output, w_o) + return output, [output, + keras.backend.concatenate([output, output], axis=-1)] + + return step_function + + # test default setup + last_output_list = [[], [], [], [], [], []] + outputs_list = [[], [], [], [], [], []] + state_list = [[], [], [], [], [], []] + additional_state_list = [[], [], [], [], [], []] + + rnn_fn = rnn_step_fn() + inputs = keras.backend.variable(input_val) + initial_states = [keras.backend.variable(init_state_val), + np.concatenate([init_state_val, init_state_val], axis=-1)] + mask = keras.backend.variable(np_mask) + + kwargs_list = [ + {'go_backwards': False, 'mask': None}, + {'go_backwards': False, 'mask': None, 'unroll': True}, + {'go_backwards': True, 'mask': None}, + {'go_backwards': True, 'mask': None, 'unroll': True}, + {'go_backwards': False, 'mask': mask}, + {'go_backwards': False, 'mask': mask, 'unroll': True}, + ] + with self.test_session(): + for i, kwargs in enumerate(kwargs_list): + last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs, + initial_states, + **kwargs) + # check static shape inference + self.assertEqual(last_output.get_shape().as_list(), + [num_samples, output_dim]) + self.assertEqual(outputs.get_shape().as_list(), + [num_samples, timesteps, output_dim]) + # for state in new_states: + # self.assertEquals(state.get_shape().as_list(), + # [num_samples, output_dim]) + self.assertEqual(new_states[0].get_shape().as_list(), + [num_samples, output_dim]) + self.assertEqual(new_states[1].get_shape().as_list(), + [num_samples, 2 * output_dim]) + + last_output_list[i].append(keras.backend.eval(last_output)) + outputs_list[i].append(keras.backend.eval(outputs)) + self.assertEqual(len(new_states), 2) + state_list[i].append(keras.backend.eval(new_states[0])) + additional_state_list[i].append(keras.backend.eval(new_states[1])) + + def assert_list_pairwise(z_list, atol=1e-05): + for (z1, z2) in zip(z_list[1:], z_list[:-1]): + self.assertAllClose(z1, z2, atol=atol) + + assert_list_pairwise(last_output_list[0], atol=1e-04) + assert_list_pairwise(outputs_list[0], atol=1e-04) + assert_list_pairwise(state_list[0], atol=1e-04) + assert_list_pairwise(additional_state_list[0], atol=1e-04) + assert_list_pairwise(last_output_list[2], atol=1e-04) + assert_list_pairwise(outputs_list[2], atol=1e-04) + assert_list_pairwise(state_list[2], atol=1e-04) + assert_list_pairwise(additional_state_list[2], atol=1e-04) + + for l, u_l in zip(last_output_list[0], last_output_list[1]): + self.assertAllClose(l, u_l, atol=1e-04) + + for o, u_o in zip(outputs_list[0], outputs_list[1]): + self.assertAllClose(o, u_o, atol=1e-04) + + for s, u_s in zip(state_list[0], state_list[1]): + self.assertAllClose(s, u_s, atol=1e-04) + + for s, u_s in zip(additional_state_list[0], additional_state_list[1]): + self.assertAllClose(s, u_s, atol=1e-04) + + for b_l, b_u_l in zip(last_output_list[2], last_output_list[3]): + self.assertAllClose(b_l, b_u_l, atol=1e-04) + + for b_o, b_u_o in zip(outputs_list[2], outputs_list[3]): + self.assertAllClose(b_o, b_u_o, atol=1e-04) + + for b_s, b_u_s in zip(state_list[2], state_list[3]): + self.assertAllClose(b_s, b_u_s, atol=1e-04) + + for s, u_s in zip(additional_state_list[2], additional_state_list[3]): + self.assertAllClose(s, u_s, atol=1e-04) + def test_normalize_batch_in_training(self): val = np.random.random((10, 3, 10, 10)) x = keras.backend.variable(val) @@ -1212,6 +1388,13 @@ class TestRandomOps(test.TestCase): self.assertAllClose(np.max(y), 2., atol=0.1) self.assertAllClose(np.min(y), -2., atol=0.1) + def test_string_input(self): + seq = keras.Sequential([ + keras.layers.InputLayer(input_shape=(1,), dtype=dtypes.string), + keras.layers.Lambda(lambda x: x[0]) + ]) + preds = seq.predict([['tensorflow eager']]) + self.assertEqual(preds.shape, (1,)) if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 9f91368e5bd772b47ac951a600f458126c1e12a6..00a9c479fb2a2414698375c304539d509829fc44 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -496,6 +496,9 @@ class EarlyStopping(Callback): monitored has stopped increasing; in `auto` mode, the direction is automatically inferred from the name of the monitored quantity. + baseline: baseline value for the monitored quantity. + Training will stop if the model doesn't show improvement over the + baseline. """ def __init__(self, @@ -503,13 +506,15 @@ class EarlyStopping(Callback): min_delta=0, patience=0, verbose=0, - mode='auto'): + mode='auto', + baseline=None): super(EarlyStopping, self).__init__() self.monitor = monitor self.patience = patience self.verbose = verbose - self.min_delta = min_delta + self.baseline = baseline + self.min_delta = abs(min_delta) self.wait = 0 self.stopped_epoch = 0 @@ -537,7 +542,10 @@ class EarlyStopping(Callback): # Allow instances to be re-used self.wait = 0 self.stopped_epoch = 0 - self.best = np.Inf if self.monitor_op == np.less else -np.Inf + if self.baseline is not None: + self.best = self.baseline + else: + self.best = np.Inf if self.monitor_op == np.less else -np.Inf def on_epoch_end(self, epoch, logs=None): current = logs.get(self.monitor) diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index 5062a26580ddb10011fd04f9a6e75ee6d2adbc68..92d66c95f6b3f184c8ead3301e4c31dfacba5333 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -273,16 +273,43 @@ class KerasCallbacksTest(test.TestCase): 1, activation='sigmoid'),)) model.compile( optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy']) - stopper = keras.callbacks.EarlyStopping(monitor='acc', patience=patience) weights = model.get_weights() + stopper = keras.callbacks.EarlyStopping(monitor='acc', patience=patience) hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20) assert len(hist.epoch) >= patience # This should allow training to go for at least `patience` epochs model.set_weights(weights) hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20) - assert len(hist.epoch) >= patience + assert len(hist.epoch) >= patience + + def test_EarlyStopping_with_baseline(self): + with self.test_session(): + np.random.seed(1337) + baseline = 0.5 + (data, labels), _ = testing_utils.get_test_data( + train_samples=100, + test_samples=50, + input_shape=(1,), + num_classes=NUM_CLASSES) + model = keras.models.Sequential((keras.layers.Dense( + 1, input_dim=1, activation='relu'), keras.layers.Dense( + 1, activation='sigmoid'),)) + model.compile( + optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy']) + + stopper = keras.callbacks.EarlyStopping(monitor='acc', + baseline=baseline) + hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20) + assert len(hist.epoch) == 1 + + patience = 3 + stopper = keras.callbacks.EarlyStopping(monitor='acc', + patience=patience, + baseline=baseline) + hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20) + assert len(hist.epoch) >= patience def test_RemoteMonitor(self): if requests is None: diff --git a/tensorflow/python/keras/datasets/boston_housing.py b/tensorflow/python/keras/datasets/boston_housing.py index 4c4cab8c0865098ebed1a7fbe29246ef51bb9833..eeb7cbc44a72a5c624f8d1d1d9dbfab1fcd1b225 100644 --- a/tensorflow/python/keras/datasets/boston_housing.py +++ b/tensorflow/python/keras/datasets/boston_housing.py @@ -45,10 +45,9 @@ def load_data(path='boston_housing.npz', test_split=0.2, seed=113): origin=origin_folder + 'boston_housing.npz', file_hash= 'f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5') - f = np.load(path) - x = f['x'] - y = f['y'] - f.close() + with np.load(path) as f: + x = f['x'] + y = f['y'] np.random.seed(seed) indices = np.arange(len(x)) diff --git a/tensorflow/python/keras/datasets/mnist.py b/tensorflow/python/keras/datasets/mnist.py index 03564accc74507713d198b8ba1ed8c08bd597e8d..a96b581960f3d5f60994fe92a1424e793d7e39c7 100644 --- a/tensorflow/python/keras/datasets/mnist.py +++ b/tensorflow/python/keras/datasets/mnist.py @@ -47,8 +47,8 @@ def load_data(path='mnist.npz'): path, origin=origin_folder + 'mnist.npz', file_hash='8a61469f7ea1b51cbae51d4f78837e45') - f = np.load(path) - x_train, y_train = f['x_train'], f['y_train'] - x_test, y_test = f['x_test'], f['y_test'] - f.close() - return (x_train, y_train), (x_test, y_test) + with np.load(path) as f: + x_train, y_train = f['x_train'], f['y_train'] + x_test, y_test = f['x_test'], f['y_test'] + + return (x_train, y_train), (x_test, y_test) diff --git a/tensorflow/python/keras/datasets/reuters.py b/tensorflow/python/keras/datasets/reuters.py index 2120b4b2421c652c9587a2e644bf008c3ece3980..cb796bb06cf09157cc510b55e3981d518fd8b433 100644 --- a/tensorflow/python/keras/datasets/reuters.py +++ b/tensorflow/python/keras/datasets/reuters.py @@ -130,7 +130,5 @@ def get_word_index(path='reuters_word_index.json'): path, origin=origin_folder + 'reuters_word_index.json', file_hash='4d44cc38712099c9e383dc6e5f11a921') - f = open(path) - data = json.load(f) - f.close() - return data + with open(path) as f: + return json.load(f) diff --git a/tensorflow/python/keras/engine/__init__.py b/tensorflow/python/keras/engine/__init__.py index ec7c0831992b2691c442bbd30445dbff8dba662f..26aed34766f9e1e2094db7a4c8b66ff057dacc4b 100644 --- a/tensorflow/python/keras/engine/__init__.py +++ b/tensorflow/python/keras/engine/__init__.py @@ -18,13 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +# TODO(fchollet): Remove hourglass imports once external code is done importing +# non-public APIs. from tensorflow.python.keras.engine.base_layer import InputSpec from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.engine.input_layer import Input from tensorflow.python.keras.engine.input_layer import InputLayer -from tensorflow.python.keras.engine.network import get_source_inputs -from tensorflow.python.keras.engine.network import Network -from tensorflow.python.keras.engine.training import Model +from tensorflow.python.keras.utils.layer_utils import get_source_inputs del absolute_import del division diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 427efaaf11ecd964d72b3e34233920d2cdfeaeeb..aa84eaa8abba9cf9004cbcb15ce80370521b4f65 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -43,7 +43,8 @@ from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.training.checkpointable import data_structures_base +from tensorflow.python.training.checkpointable import data_structures +from tensorflow.python.training.checkpointable import layer_utils as checkpointable_layer_utils from tensorflow.python.training.checkpointable import util as checkpointable_utils from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect @@ -368,7 +369,7 @@ class Network(base_layer.Layer): if isinstance(value, ( base_layer.Layer, Network, - data_structures_base.CheckpointableDataStructureBase)): + data_structures.CheckpointableDataStructure)): try: is_graph_network = self._is_graph_network except AttributeError: @@ -527,6 +528,28 @@ class Network(base_layer.Layer): return layer raise ValueError('No such layer: ' + name) + @property + def _unfiltered_updates(self): + if context.executing_eagerly(): + return [] + updates = [] + for layer in self.layers: + if isinstance(layer, Network): + updates += layer._unfiltered_updates + else: + updates += layer.updates + return updates + + @property + def _unfiltered_losses(self): + losses = [] + for layer in self.layers: + if isinstance(layer, Network): + losses += layer._unfiltered_losses + else: + losses += layer.losses + return losses + @property def updates(self): """Retrieves the network's updates. @@ -536,6 +559,8 @@ class Network(base_layer.Layer): (e.g. will not include updates that were created by layers of this model outside of the model). + When the network has no registered inputs, all updates are returned. + Effectively, `network.updates` behaves like `layer.updates`. Concrete example: @@ -581,22 +606,20 @@ class Network(base_layer.Layer): if not self.trainable and not self.stateful: return [] - updates = [] - for layer in self.layers: - updates += layer.updates + updates = self._unfiltered_updates # `updates` might contain irrelevant updates, so it needs to be filtered # with respect to inputs the model has been called on. - if self.inputs: - relevant_inputs = self.inputs[:] - else: - relevant_inputs = [] - for i in range(1, len(self._inbound_nodes)): + relevant_inputs = [] + for i in range(0, len(self._inbound_nodes)): inputs = self.get_input_at(i) if isinstance(inputs, list): relevant_inputs += inputs else: relevant_inputs.append(inputs) + if not relevant_inputs: + return updates + reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, updates) relevant_conditional_updates = [x for x in updates if x in reachable] unconditional_updates = [ @@ -615,25 +638,25 @@ class Network(base_layer.Layer): (e.g. will not include losses that depend on tensors that aren't inputs to this model). + When the network has no registered inputs, all losses are returned. + Returns: A list of loss tensors. """ - losses = [] - for layer in self.layers: - losses += layer.losses + losses = self._unfiltered_losses if context.executing_eagerly(): return losses - if self.inputs: - relevant_inputs = self.inputs[:] - else: - relevant_inputs = [] - for i in range(1, len(self._inbound_nodes)): + relevant_inputs = [] + for i in range(0, len(self._inbound_nodes)): inputs = self.get_input_at(i) if isinstance(inputs, list): relevant_inputs += inputs else: relevant_inputs.append(inputs) + if not relevant_inputs: + return losses + reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, losses) relevant_conditional_losses = [x for x in losses if x in reachable] unconditional_losses = [ @@ -643,14 +666,14 @@ class Network(base_layer.Layer): @property def trainable_weights(self): - return layer_utils.gather_trainable_weights( + return checkpointable_layer_utils.gather_trainable_weights( trainable=self.trainable, sub_layers=self.layers, extra_variables=self._extra_variables) @property def non_trainable_weights(self): - return layer_utils.gather_non_trainable_weights( + return checkpointable_layer_utils.gather_non_trainable_weights( trainable=self.trainable, sub_layers=self.layers, extra_variables=self._extra_variables) @@ -1496,47 +1519,6 @@ class Network(base_layer.Layer): print_fn=print_fn) -def get_source_inputs(tensor, layer=None, node_index=None): - """Returns the list of input tensors necessary to compute `tensor`. - - Output will always be a list of tensors - (potentially with 1 element). - - Arguments: - tensor: The tensor to start from. - layer: Origin layer of the tensor. Will be - determined via tensor._keras_history if not provided. - node_index: Origin node index of the tensor. - - Returns: - List of input tensors. - """ - if not hasattr(tensor, '_keras_history'): - return tensor - - if layer is None or node_index: - layer, node_index, _ = tensor._keras_history - if not layer._inbound_nodes: - return [tensor] - else: - node = layer._inbound_nodes[node_index] - if not node.inbound_layers: - # Reached an Input layer, stop recursion. - return node.input_tensors - else: - source_tensors = [] - for i in range(len(node.inbound_layers)): - x = node.input_tensors[i] - layer = node.inbound_layers[i] - node_index = node.node_indices[i] - previous_sources = get_source_inputs(x, layer, node_index) - # Avoid input redundancy. - for x in previous_sources: - if x not in source_tensors: - source_tensors.append(x) - return source_tensors - - def _is_hdf5_filepath(filepath): return filepath.endswith('.h5') or filepath.endswith('.keras') diff --git a/tensorflow/python/keras/engine/saving.py b/tensorflow/python/keras/engine/saving.py index b9a2e1f25f637dc8017f751bbdd400c1e5c9dd44..d5ccd44604b6b84ea0ceb4fa1c270b2c7dddc147 100644 --- a/tensorflow/python/keras/engine/saving.py +++ b/tensorflow/python/keras/engine/saving.py @@ -351,7 +351,10 @@ def preprocess_weights_for_loading(layer, weights, original_keras_version=None, original_backend=None): - """Converts layers weights from Keras 1 format to Keras 2. + """Preprocess layer weights between different Keras formats. + + Converts layers weights from Keras 1 format to Keras 2 and also weights of + CuDNN layers in Keras 2. Arguments: layer: Layer instance. @@ -363,7 +366,18 @@ def preprocess_weights_for_loading(layer, Returns: A list of weights values (Numpy arrays). """ - if layer.__class__.__name__ == 'Bidirectional': + def convert_nested_bidirectional(weights): + """Converts layers nested in `Bidirectional` wrapper. + + This function uses `preprocess_weights_for_loading()` for converting + layers. + + Arguments: + weights: List of weights values (Numpy arrays). + + Returns: + A list of weights values (Numpy arrays). + """ num_weights_per_layer = len(weights) // 2 forward_weights = preprocess_weights_for_loading( layer.forward_layer, weights[:num_weights_per_layer], @@ -371,7 +385,69 @@ def preprocess_weights_for_loading(layer, backward_weights = preprocess_weights_for_loading( layer.backward_layer, weights[num_weights_per_layer:], original_keras_version, original_backend) - weights = forward_weights + backward_weights + return forward_weights + backward_weights + + def convert_nested_time_distributed(weights): + """Converts layers nested in `TimeDistributed` wrapper. + + This function uses `preprocess_weights_for_loading()` for converting nested + layers. + + Arguments: + weights: List of weights values (Numpy arrays). + + Returns: + A list of weights values (Numpy arrays). + """ + return preprocess_weights_for_loading( + layer.layer, weights, original_keras_version, original_backend) + + def convert_nested_model(weights): + """Converts layers nested in `Model` or `Sequential`. + + This function uses `preprocess_weights_for_loading()` for converting nested + layers. + + Arguments: + weights: List of weights values (Numpy arrays). + + Returns: + A list of weights values (Numpy arrays). + """ + new_weights = [] + # trainable weights + for sublayer in layer.layers: + num_weights = len(sublayer.trainable_weights) + if num_weights > 0: + new_weights.extend(preprocess_weights_for_loading( + layer=sublayer, + weights=weights[:num_weights], + original_keras_version=original_keras_version, + original_backend=original_backend)) + weights = weights[num_weights:] + + # non-trainable weights + for sublayer in layer.layers: + num_weights = len([l for l in sublayer.weights + if l not in sublayer.trainable_weights]) + if num_weights > 0: + new_weights.extend(preprocess_weights_for_loading( + layer=sublayer, + weights=weights[:num_weights], + original_keras_version=original_keras_version, + original_backend=original_backend)) + weights = weights[num_weights:] + return new_weights + + # Convert layers nested in Bidirectional/Model/Sequential. + # Both transformation should be ran for both Keras 1->2 conversion + # and for conversion of CuDNN layers. + if layer.__class__.__name__ == 'Bidirectional': + weights = convert_nested_bidirectional(weights) + if layer.__class__.__name__ == 'TimeDistributed': + weights = convert_nested_time_distributed(weights) + elif layer.__class__.__name__ in ['Model', 'Sequential']: + weights = convert_nested_model(weights) if original_keras_version == '1': if layer.__class__.__name__ == 'TimeDistributed': @@ -446,35 +522,6 @@ def preprocess_weights_for_loading(layer, recurrent_kernel = np.transpose(recurrent_kernel, (2, 3, 1, 0)) weights = [kernel, recurrent_kernel, bias] - if layer.__class__.__name__ in ['Model', 'Sequential']: - new_weights = [] - # trainable weights - for sublayer in layer.layers: - num_weights = len(sublayer.trainable_weights) - if num_weights > 0: - new_weights.extend( - preprocess_weights_for_loading( - layer=sublayer, - weights=weights[:num_weights], - original_keras_version=original_keras_version, - original_backend=original_backend)) - weights = weights[num_weights:] - - # non-trainable weights - for sublayer in layer.layers: - num_weights = len([ - l for l in sublayer.weights if l not in sublayer.trainable_weights - ]) - if num_weights > 0: - new_weights.extend( - preprocess_weights_for_loading( - layer=sublayer, - weights=weights[:num_weights], - original_keras_version=original_keras_version, - original_backend=original_backend)) - weights = weights[num_weights:] - weights = new_weights - conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'ConvLSTM2D'] if layer.__class__.__name__ in conv_layers: if original_backend == 'theano': @@ -486,6 +533,7 @@ def preprocess_weights_for_loading(layer, if layer.__class__.__name__ == 'ConvLSTM2D': weights[1] = np.transpose(weights[1], (3, 2, 0, 1)) + # convert CuDNN layers return _convert_rnn_weights(layer, weights) @@ -624,7 +672,7 @@ def _convert_rnn_weights(layer, weights): kernels = transform_kernels(weights[0], transpose_input(from_cudnn), n_gates) recurrent_kernels = transform_kernels(weights[1], lambda k: k.T, n_gates) - biases = weights[2].reshape((2, -1) if from_cudnn else -1) + biases = np.array(weights[2]).reshape((2, -1) if from_cudnn else -1) return [kernels, recurrent_kernels, biases] if bias_shape == (2 * units * n_gates,): @@ -806,7 +854,16 @@ def load_weights_from_hdf5_group_by_name(f, layers): str(len(weight_values)) + ' element(s).') # Set values. for i in range(len(weight_values)): - weight_value_tuples.append((symbolic_weights[i], weight_values[i])) + if K.int_shape(symbolic_weights[i]) != weight_values[i].shape: + raise ValueError('Layer #' + str(k) +' (named "' + layer.name + + '"), weight ' + str(symbolic_weights[i]) + + ' has shape {}'.format(K.int_shape( + symbolic_weights[i])) + + ', but the saved weight has shape ' + + str(weight_values[i].shape) + '.') + + else: + weight_value_tuples.append((symbolic_weights[i], weight_values[i])) K.batch_set_value(weight_value_tuples) diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py index 7e82db028b8db01d6fe2b693e2087cbdfac55314..030328f2a66f0ec406ac271aecfbf2dbebf22f5f 100644 --- a/tensorflow/python/keras/engine/saving_test.py +++ b/tensorflow/python/keras/engine/saving_test.py @@ -21,7 +21,6 @@ from __future__ import print_function import os import shutil import tempfile - from absl.testing import parameterized import numpy as np @@ -31,6 +30,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.keras.engine import saving from tensorflow.python.keras.engine import training from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops @@ -248,6 +248,82 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase): self.assertAllClose(y, ref_y) + def test_sequential_weight_loading_group_name_with_incorrect_length(self): + if h5py is None: + return + + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + h5_path = os.path.join(temp_dir, 'test.h5') + + num_hidden = 5 + input_dim = 3 + num_classes = 2 + with self.test_session(): + ref_model = keras.models.Sequential() + ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim, + name='d1')) + ref_model.add(keras.layers.Dense(num_classes, name='d2')) + ref_model.compile(loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy]) + + f_ref_model = h5py.File(h5_path, 'w') + saving.save_weights_to_hdf5_group(f_ref_model, ref_model.layers) + + f_model = h5py.File(h5_path, 'r') + model = keras.models.Sequential() + model.add(keras.layers.Dense(num_hidden, use_bias=False, + input_dim=input_dim, name='d1')) + model.add(keras.layers.Dense(num_classes, name='d2')) + model.compile(loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy]) + with self.assertRaisesRegexp(ValueError, + r'Layer #0 \(named \"d1\"\) expects 1 ' + r'weight\(s\), but the saved weights have 2 ' + r'element\(s\)\.'): + saving.load_weights_from_hdf5_group_by_name(f_model, model.layers) + + def test_sequential_weight_loading_group_name_with_incorrect_shape(self): + if h5py is None: + return + + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + h5_path = os.path.join(temp_dir, 'test.h5') + + num_hidden = 5 + input_dim = 3 + num_classes = 2 + with self.test_session(): + ref_model = keras.models.Sequential() + ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim, + name='d1')) + ref_model.add(keras.layers.Dense(num_classes, name='d2')) + ref_model.compile(loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy]) + + f_ref_model = h5py.File(h5_path, 'w') + saving.save_weights_to_hdf5_group(f_ref_model, ref_model.layers) + + f_model = h5py.File(h5_path, 'r') + model = keras.models.Sequential() + model.add(keras.layers.Dense(num_hidden + 5, input_dim=input_dim, + name='d1')) + model.add(keras.layers.Dense(num_classes, name='d2')) + model.compile(loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy]) + with self.assertRaisesRegexp(ValueError, + r'Layer #0 \(named "d1"\), weight ' + r' has ' + r'shape \(3, 10\), but the saved weight has ' + r'shape \(3, 5\)\.'): + saving.load_weights_from_hdf5_group_by_name(f_model, model.layers) + class TestWholeModelSaving(test.TestCase): @@ -587,7 +663,7 @@ class SubclassedModel(training.Model): class TestWeightSavingAndLoadingTFFormat(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_tensorflow_format_overwrite(self): with self.test_session() as session: model = SubclassedModel() @@ -676,7 +752,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): restore_on_create_y = self.evaluate(restore_on_create_y_tensor) self.assertAllClose(ref_y, restore_on_create_y) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_weight_loading_graph_model(self): def _make_graph_model(): a = keras.layers.Input(shape=(2,)) @@ -686,7 +762,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): self._weight_loading_test_template(_make_graph_model) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_weight_loading_subclassed_model(self): self._weight_loading_test_template(SubclassedModel) @@ -720,7 +796,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): y = self.evaluate(model(x)) self.assertAllClose(ref_y, y) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_weight_loading_graph_model_added_layer(self): def _save_graph_model(): a = keras.layers.Input(shape=(2,)) @@ -740,7 +816,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): _save_graph_model, _restore_graph_model, _restore_init_fn) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_weight_loading_graph_model_added_no_weight_layer(self): def _save_graph_model(): a = keras.layers.Input(shape=(2,)) @@ -761,7 +837,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): _save_graph_model, _restore_graph_model, _restore_init_fn) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_weight_loading_subclassed_model_added_layer(self): class SubclassedModelRestore(training.Model): diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py index 3ca8fdd3260d964442c18bc30c3925f252e8a304..cd76f08a32505a0f408edcd129e44b39099646f3 100644 --- a/tensorflow/python/keras/engine/sequential.py +++ b/tensorflow/python/keras/engine/sequential.py @@ -24,10 +24,10 @@ import copy from tensorflow.python.keras import backend as K from tensorflow.python.keras import layers as layer_module from tensorflow.python.keras.engine import base_layer -from tensorflow.python.keras.engine import network from tensorflow.python.keras.engine.input_layer import Input from tensorflow.python.keras.engine.input_layer import InputLayer from tensorflow.python.keras.engine.training import Model +from tensorflow.python.keras.utils import layer_utils from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -146,8 +146,6 @@ class Sequential(Model): first_layer = layer.layers[0] while isinstance(first_layer, (Model, Sequential)): first_layer = first_layer.layers[0] - batch_shape = first_layer._batch_input_shape - dtype = first_layer.dtype if hasattr(first_layer, '_batch_input_shape'): batch_shape = first_layer._batch_input_shape @@ -179,7 +177,7 @@ class Sequential(Model): 'use the functional API.') self.outputs = [layer._inbound_nodes[-1].output_tensors[0]] - self.inputs = network.get_source_inputs(self.outputs[0]) + self.inputs = layer_utils.get_source_inputs(self.outputs[0]) elif self.outputs: output_tensor = layer(self.outputs[0]) if isinstance(output_tensor, list): diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py index cdaf9162de8c5318c7e092077d409f05a7edc717..0f54e29cee38bd12d691b03ae98d3e578b7ff907 100644 --- a/tensorflow/python/keras/engine/sequential_test.py +++ b/tensorflow/python/keras/engine/sequential_test.py @@ -33,7 +33,7 @@ class TestSequential(test.TestCase): """Most Sequential model API tests are covered in `training_test.py`. """ - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_basic_methods(self): model = keras.models.Sequential() model.add(keras.layers.Dense(1, input_dim=2)) @@ -44,7 +44,7 @@ class TestSequential(test.TestCase): self.assertEqual(len(model.weights), 2 * 2) self.assertEqual(model.get_layer(name='dp').name, 'dp') - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_sequential_pop(self): num_hidden = 5 input_dim = 3 @@ -77,7 +77,7 @@ class TestSequential(test.TestCase): with self.assertRaises(TypeError): model.pop() - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_sequential_deferred_build_with_np_arrays(self): num_hidden = 5 input_dim = 3 @@ -102,7 +102,7 @@ class TestSequential(test.TestCase): [None, num_classes]) self.assertEqual(len(model.weights), 2 * 2) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_sequential_deferred_build_with_dataset_iterators(self): if not context.executing_eagerly(): # TODO(psv/fchollet): Add support for this use case in graph mode. @@ -136,7 +136,7 @@ class TestSequential(test.TestCase): [None, num_classes]) self.assertEqual(len(model.weights), 2 * 2) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_invalid_use_cases(self): # Added objects must be layer instances with self.assertRaises(TypeError): @@ -160,7 +160,7 @@ class TestSequential(test.TestCase): model.add(keras.layers.Dense(1, input_dim=1)) model.add(MyLayer()) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_nested_sequential_trainability(self): input_dim = 20 num_units = 10 diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py index 183e26e8bf813ec0a8c84920a93dcb79a291ca9d..3eb69bd7f3d42f5cd8d6cc6d2d32cc9eb808d9a4 100644 --- a/tensorflow/python/keras/engine/topology_test.py +++ b/tensorflow/python/keras/engine/topology_test.py @@ -26,6 +26,8 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.keras.engine import base_layer +from tensorflow.python.keras.engine import input_layer as input_layer_lib +from tensorflow.python.keras.engine import network as network_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops @@ -62,7 +64,7 @@ class TopologyConstructionTest(test.TestCase): inputs=True) return inputs + 1 - x1 = keras.Input(shape=(1,)) + x1 = input_layer_lib.Input(shape=(1,)) layer = MyLayer() _ = layer.apply(x1) @@ -70,7 +72,7 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(len(layer.get_updates_for(x1)), 1) self.assertEqual(len(layer.get_updates_for(None)), 1) - x2 = keras.Input(shape=(1,)) + x2 = input_layer_lib.Input(shape=(1,)) y2 = layer.apply(x2) self.assertEqual(len(layer.updates), 3) @@ -78,17 +80,17 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(len(layer.get_updates_for(x2)), 1) self.assertEqual(len(layer.get_updates_for(None)), 1) - network = keras.engine.Network(x2, y2) + network = network_lib.Network(x2, y2) self.assertEqual(len(network.updates), 2) self.assertEqual(len(network.get_updates_for(x1)), 0) self.assertEqual(len(network.get_updates_for(x2)), 1) self.assertEqual(len(network.get_updates_for(None)), 1) - x3 = keras.Input(shape=(1,)) + x3 = input_layer_lib.Input(shape=(1,)) _ = layer.apply(x3) self.assertEqual(len(network.updates), 2) - x4 = keras.Input(shape=(1,)) + x4 = input_layer_lib.Input(shape=(1,)) _ = network(x4) self.assertEqual(len(network.updates), 3) self.assertEqual(len(network.get_updates_for(x2)), 1) @@ -104,7 +106,7 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(len(network.get_updates_for(x4)), 2) def test_get_updates_bn(self): - x1 = keras.Input(shape=(1,)) + x1 = input_layer_lib.Input(shape=(1,)) layer = keras.layers.BatchNormalization() _ = layer.apply(x1) @@ -134,7 +136,7 @@ class TopologyConstructionTest(test.TestCase): inputs=True) return inputs + 1 - x1 = keras.Input(shape=(1,)) + x1 = input_layer_lib.Input(shape=(1,)) layer = MyLayer() _ = layer.apply(x1) @@ -142,7 +144,7 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(len(layer.get_losses_for(x1)), 1) self.assertEqual(len(layer.get_losses_for(None)), 1) - x2 = keras.Input(shape=(1,)) + x2 = input_layer_lib.Input(shape=(1,)) y2 = layer.apply(x2) self.assertEqual(len(layer.losses), 3) @@ -150,17 +152,17 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(len(layer.get_losses_for(x2)), 1) self.assertEqual(len(layer.get_losses_for(None)), 1) - network = keras.engine.Network(x2, y2) + network = network_lib.Network(x2, y2) self.assertEqual(len(network.losses), 2) self.assertEqual(len(network.get_losses_for(x1)), 0) self.assertEqual(len(network.get_losses_for(x2)), 1) self.assertEqual(len(network.get_losses_for(None)), 1) - x3 = keras.Input(shape=(1,)) + x3 = input_layer_lib.Input(shape=(1,)) _ = layer.apply(x3) self.assertEqual(len(network.losses), 2) - x4 = keras.Input(shape=(1,)) + x4 = input_layer_lib.Input(shape=(1,)) _ = network(x4) self.assertEqual(len(network.losses), 3) self.assertEqual(len(network.get_losses_for(x2)), 1) @@ -177,8 +179,8 @@ class TopologyConstructionTest(test.TestCase): def testTopologicalAttributes(self): # test layer attributes / methods related to cross-layer connectivity. - a = keras.Input(shape=(32,), name='input_a') - b = keras.Input(shape=(32,), name='input_b') + a = input_layer_lib.Input(shape=(32,), name='input_a') + b = input_layer_lib.Input(shape=(32,), name='input_b') # test input, output, input_shape, output_shape test_layer = keras.layers.Dense(16, name='test_layer') @@ -219,15 +221,15 @@ class TopologyConstructionTest(test.TestCase): _ = new_dense.input_shape with self.assertRaises(AttributeError): new_dense = keras.layers.Dense(16) - a = keras.Input(shape=(3, 32)) - a = keras.Input(shape=(5, 32)) + a = input_layer_lib.Input(shape=(3, 32)) + a = input_layer_lib.Input(shape=(5, 32)) a_2 = dense(a) b_2 = dense(b) _ = new_dense.input_shape with self.assertRaises(AttributeError): new_dense = keras.layers.Dense(16) - a = keras.Input(shape=(3, 32)) - a = keras.Input(shape=(5, 32)) + a = input_layer_lib.Input(shape=(3, 32)) + a = input_layer_lib.Input(shape=(5, 32)) a_2 = dense(a) b_2 = dense(b) _ = new_dense.output_shape @@ -239,7 +241,7 @@ class TopologyConstructionTest(test.TestCase): def call(self, inputs): return [inputs**2, inputs**3] - x = keras.Input(shape=(32,)) + x = input_layer_lib.Input(shape=(32,)) test_layer = PowersLayer() p1, p2 = test_layer(x) # pylint: disable=not-callable @@ -256,8 +258,8 @@ class TopologyConstructionTest(test.TestCase): assert len(inputs) == 2 return inputs[0] + inputs[1] - a = keras.Input(shape=(32,)) - b = keras.Input(shape=(32,)) + a = input_layer_lib.Input(shape=(32,)) + b = input_layer_lib.Input(shape=(32,)) test_layer = AddLayer() y = test_layer([a, b]) # pylint: disable=not-callable @@ -268,10 +270,10 @@ class TopologyConstructionTest(test.TestCase): def testBasicNetwork(self): # minimum viable network - x = keras.Input(shape=(32,)) + x = input_layer_lib.Input(shape=(32,)) dense = keras.layers.Dense(2) y = dense(x) - network = keras.engine.Network(x, y, name='dense_network') + network = network_lib.Network(x, y, name='dense_network') # test basic attributes self.assertEqual(network.name, 'dense_network') @@ -282,7 +284,7 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(network.non_trainable_weights, dense.non_trainable_weights) # test callability on Input - x_2 = keras.Input(shape=(32,)) + x_2 = input_layer_lib.Input(shape=(32,)) y_2 = network(x_2) self.assertEqual(y_2.get_shape().as_list(), [None, 2]) @@ -506,7 +508,7 @@ class TopologyConstructionTest(test.TestCase): self.assertListEqual([x.shape for x in fn_outputs], [(10, 64), (10, 5)]) # test get_source_inputs - self.assertListEqual(keras.engine.network.get_source_inputs(c), [a, b]) + self.assertListEqual(keras.engine.get_source_inputs(c), [a, b]) # serialization / deserialization json_config = model.to_json() @@ -778,12 +780,12 @@ class TopologyConstructionTest(test.TestCase): self.evaluate(getattr(b, '_keras_mask'))) self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b)) else: - x = keras.Input(shape=(32,)) + x = input_layer_lib.Input(shape=(32,)) y = MaskedLayer()(x) # pylint: disable=not-callable - network = keras.engine.Network(x, y) + network = network_lib.Network(x, y) # test callability on Input - x_2 = keras.Input(shape=(32,)) + x_2 = input_layer_lib.Input(shape=(32,)) y_2 = network(x_2) self.assertEqual(y_2.get_shape().as_list(), [None, 32]) @@ -797,14 +799,14 @@ class TopologyConstructionTest(test.TestCase): def reg(x): return math_ops.reduce_sum(x) - net_a_input = keras.Input((2,)) + net_a_input = input_layer_lib.Input((2,)) net_a = net_a_input net_a = keras.layers.Dense(2, kernel_initializer='ones', use_bias=False, activity_regularizer=reg)(net_a) model_a = keras.Model([net_a_input], [net_a]) - net_b_input = keras.Input((2,)) + net_b_input = input_layer_lib.Input((2,)) net_b = model_a(net_b_input) model_b = keras.Model([net_b_input], [net_b]) @@ -817,7 +819,7 @@ class TopologyConstructionTest(test.TestCase): with self.test_session(): x_val = np.random.random((10, 5)) - x = keras.Input(shape=(5,)) + x = input_layer_lib.Input(shape=(5,)) a = keras.layers.Dense(5, name='A') b = keras.layers.Dense(5, name='B') output = a(b(a(b(x)))) @@ -837,7 +839,7 @@ class TopologyConstructionTest(test.TestCase): def test_layer_sharing_at_heterogenous_depth_with_concat(self): with self.test_session(): input_shape = (16, 9, 3) - input_layer = keras.Input(shape=input_shape) + input_layer = input_layer_lib.Input(shape=input_shape) a = keras.layers.Dense(3, name='dense_A') b = keras.layers.Dense(3, name='dense_B') @@ -924,7 +926,7 @@ class DeferredModeTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testSimpleNetworkBuilding(self): - inputs = keras.engine.Input(shape=(32,)) + inputs = input_layer_lib.Input(shape=(32,)) if context.executing_eagerly(): self.assertIsInstance(inputs, base_layer.DeferredTensor) self.assertEqual(inputs.dtype.name, 'float32') @@ -937,8 +939,8 @@ class DeferredModeTest(test.TestCase): self.assertEqual(x.shape.as_list(), [None, 2]) outputs = keras.layers.Dense(4)(x) - network = keras.engine.Network(inputs, outputs) - self.assertIsInstance(network, keras.engine.Network) + network = network_lib.Network(inputs, outputs) + self.assertIsInstance(network, network_lib.Network) if context.executing_eagerly(): # It should be possible to call such a network on EagerTensors. @@ -949,8 +951,8 @@ class DeferredModeTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testMultiIONetworkbuilding(self): - input_a = keras.engine.Input(shape=(32,)) - input_b = keras.engine.Input(shape=(16,)) + input_a = input_layer_lib.Input(shape=(32,)) + input_b = input_layer_lib.Input(shape=(16,)) a = keras.layers.Dense(16)(input_a) class AddLayer(keras.layers.Layer): @@ -964,7 +966,7 @@ class DeferredModeTest(test.TestCase): c = AddLayer()([a, input_b]) # pylint: disable=not-callable c = keras.layers.Dense(2)(c) - network = keras.engine.Network([input_a, input_b], [a, c]) + network = network_lib.Network([input_a, input_b], [a, c]) if context.executing_eagerly(): a_val = constant_op.constant( np.random.random((10, 32)).astype('float32')) diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py index 1571a7782aaae3c3be233c7d4fdd91456a96e3e3..bdb30351290644e2f7e8135c047ef6732054a08a 100644 --- a/tensorflow/python/keras/engine/training_eager_test.py +++ b/tensorflow/python/keras/engine/training_eager_test.py @@ -647,7 +647,7 @@ class LossWeightingTest(test.TestCase): class CorrectnessTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_loss_correctness(self): # Test that training loss is the same in eager and graph # (by comparing it to a reference value in a deterministic case) @@ -668,7 +668,7 @@ class CorrectnessTest(test.TestCase): self.assertEqual( np.around(history.history['loss'][-1], decimals=4), 0.6173) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_metrics_correctness(self): model = keras.Sequential() model.add(keras.layers.Dense(3, @@ -689,7 +689,7 @@ class CorrectnessTest(test.TestCase): outs = model.evaluate(x, y) self.assertEqual(outs[1], 0.) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_loss_correctness_with_iterator(self): # Test that training loss is the same in eager and graph # (by comparing it to a reference value in a deterministic case) @@ -712,7 +712,7 @@ class CorrectnessTest(test.TestCase): history = model.fit(iterator, epochs=1, steps_per_epoch=10) self.assertEqual(np.around(history.history['loss'][-1], decimals=4), 0.6173) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_metrics_correctness_with_iterator(self): model = keras.Sequential() model.add( diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index a1ab72018957a26984f9b7b1ccba9a128a136866..d9e548f01f86fd96c3abd7b3cdaf5106653393fd 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -1696,7 +1696,7 @@ class TestTrainingWithDataTensors(test.TestCase): model.train_on_batch([input_a_np, input_b_np], [output_a_np, output_b_np]) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_metric_names_are_identical_in_graph_and_eager(self): a = keras.layers.Input(shape=(3,), name='input_a') b = keras.layers.Input(shape=(3,), name='input_b') @@ -1723,7 +1723,7 @@ class TestTrainingWithDataTensors(test.TestCase): class TestTrainingWithDatasetIterators(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_training_and_eval_methods_on_iterators_single_io(self): with self.test_session(): x = keras.layers.Input(shape=(3,), name='input') @@ -1813,7 +1813,7 @@ class TestTrainingWithDatasetIterators(test.TestCase): ops.get_default_graph().finalize() model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_iterators_running_out_of_data(self): with self.test_session(): x = keras.layers.Input(shape=(3,), name='input') @@ -1867,7 +1867,7 @@ class TestTrainingWithDataset(test.TestCase): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, validation_data=dataset, validation_steps=2) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_training_and_eval_methods_on_dataset(self): with self.test_session(): x = keras.layers.Input(shape=(3,), name='input') diff --git a/tensorflow/python/keras/estimator/__init__.py b/tensorflow/python/keras/estimator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b244beb5b58cf339a4687216b87418c88b953c17 --- /dev/null +++ b/tensorflow/python/keras/estimator/__init__.py @@ -0,0 +1,46 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Keras estimator API.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.util.tf_export import tf_export + +# Keras has undeclared dependency on tensorflow/estimator:estimator_py. +# As long as you depend //third_party/py/tensorflow:tensorflow target +# everything will work as normal. + +try: + from tensorflow.python.estimator import keras as keras_lib # pylint: disable=g-import-not-at-top + model_to_estimator = tf_export('keras.estimator.model_to_estimator')( + keras_lib.model_to_estimator) +except Exception: # pylint: disable=broad-except + + # pylint: disable=unused-argument + def stub_model_to_estimator(keras_model=None, + keras_model_path=None, + custom_objects=None, + model_dir=None, + config=None): + raise NotImplementedError( + 'tf.keras.estimator.model_to_estimator function not available in your ' + 'installation.') + # pylint: enable=unused-argument + + model_to_estimator = tf_export('keras.estimator.model_to_estimator')( + stub_model_to_estimator) + diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py index ce0cdb2e1b34d1133f64965818b0a2bcef108d86..e3a686f45d92dde8ea90d496b3cb5099f6b84b58 100644 --- a/tensorflow/python/keras/layers/__init__.py +++ b/tensorflow/python/keras/layers/__init__.py @@ -20,15 +20,16 @@ from __future__ import print_function # Generic layers. # pylint: disable=g-bad-import-order -from tensorflow.python.keras.engine import Input -from tensorflow.python.keras.engine import InputLayer -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.input_layer import Input +from tensorflow.python.keras.engine.input_layer import InputLayer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer # Advanced activations. from tensorflow.python.keras.layers.advanced_activations import LeakyReLU from tensorflow.python.keras.layers.advanced_activations import PReLU from tensorflow.python.keras.layers.advanced_activations import ELU +from tensorflow.python.keras.layers.advanced_activations import ReLU from tensorflow.python.keras.layers.advanced_activations import ThresholdedReLU from tensorflow.python.keras.layers.advanced_activations import Softmax diff --git a/tensorflow/python/keras/layers/advanced_activations.py b/tensorflow/python/keras/layers/advanced_activations.py index 8ade3c317456a88181f6005c620953817463595b..eba10da6f3ce1367f4cb0180d16efdc5913fcddc 100644 --- a/tensorflow/python/keras/layers/advanced_activations.py +++ b/tensorflow/python/keras/layers/advanced_activations.py @@ -23,8 +23,8 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras import constraints from tensorflow.python.keras import initializers from tensorflow.python.keras import regularizers -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import math_ops from tensorflow.python.util.tf_export import tf_export @@ -278,3 +278,40 @@ class Softmax(Layer): @tf_utils.shape_type_conversion def compute_output_shape(self, input_shape): return input_shape + + +@tf_export('keras.layers.ReLU') +class ReLU(Layer): + """Rectified Linear Unit activation function. + + Input shape: + Arbitrary. Use the keyword argument `input_shape` + (tuple of integers, does not include the samples axis) + when using this layer as the first layer in a model. + + Output shape: + Same shape as the input. + + Arguments: + max_value: float >= 0. Maximum activation value. + """ + + def __init__(self, max_value=None, **kwargs): + super(ReLU, self).__init__(**kwargs) + self.support_masking = True + self.max_value = K.cast_to_floatx(max_value) + if self.max_value < 0.: + raise ValueError('max_value of Relu layer ' + 'cannot be negative value: ' + str(max_value)) + + def call(self, inputs): + return activations.relu(inputs, max_value=self.max_value) + + def get_config(self): + config = {'max_value': self.max_value} + base_config = super(ReLU, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @tf_utils.shape_type_conversion + def compute_output_shape(self, input_shape): + return input_shape diff --git a/tensorflow/python/keras/layers/advanced_activations_test.py b/tensorflow/python/keras/layers/advanced_activations_test.py index 81c76db14cd3741687bf5e2bec66e5354e9f6312..9e1f15b1bc508d8be0a2c0190d07eb1c2bed95c4 100644 --- a/tensorflow/python/keras/layers/advanced_activations_test.py +++ b/tensorflow/python/keras/layers/advanced_activations_test.py @@ -62,6 +62,20 @@ class AdvancedActivationsTest(test.TestCase): kwargs={'axis': 1}, input_shape=(2, 3, 4)) + def test_relu(self): + with self.test_session(): + testing_utils.layer_test(keras.layers.ReLU, + kwargs={'max_value': 10}, + input_shape=(2, 3, 4)) + + def test_relu_with_invalid_arg(self): + with self.assertRaisesRegexp( + ValueError, 'max_value of Relu layer cannot be negative value: -10'): + with self.test_session(): + testing_utils.layer_test(keras.layers.ReLU, + kwargs={'max_value': -10}, + input_shape=(2, 3, 4)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py index 1c2a77d29731038040634126392e41c6eee76391..a57ac121ed7486a9beb64e6dd7ed3b132ca258df 100644 --- a/tensorflow/python/keras/layers/convolutional.py +++ b/tensorflow/python/keras/layers/convolutional.py @@ -26,8 +26,8 @@ from tensorflow.python.keras import backend from tensorflow.python.keras import constraints from tensorflow.python.keras import initializers from tensorflow.python.keras import regularizers -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer # imports for backwards namespace compatibility # pylint: disable=unused-import from tensorflow.python.keras.layers.pooling import AveragePooling1D @@ -1195,6 +1195,7 @@ class SeparableConv(Conv): dilation_rate=dilation_rate, activation=activations.get(activation), use_bias=use_bias, + bias_initializer=initializers.get(bias_initializer), bias_regularizer=regularizers.get(bias_regularizer), activity_regularizer=regularizers.get(activity_regularizer), bias_constraint=bias_constraint, diff --git a/tensorflow/python/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/layers/convolutional_recurrent.py index c731508b3c32d93895432fd5174c1f57557b10dc..84d794cada86b15755c28592d4c8093a4d3ef87e 100644 --- a/tensorflow/python/keras/layers/convolutional_recurrent.py +++ b/tensorflow/python/keras/layers/convolutional_recurrent.py @@ -26,8 +26,8 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras import constraints from tensorflow.python.keras import initializers from tensorflow.python.keras import regularizers -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.layers.recurrent import _generate_dropout_mask from tensorflow.python.keras.layers.recurrent import _standardize_args from tensorflow.python.keras.layers.recurrent import RNN diff --git a/tensorflow/python/keras/layers/convolutional_test.py b/tensorflow/python/keras/layers/convolutional_test.py index 39988ba33ae7252ac6dd61ee7ee37b7335ffad24..f904744422a4b1296e8f5e8a34373fd0344dc643 100644 --- a/tensorflow/python/keras/layers/convolutional_test.py +++ b/tensorflow/python/keras/layers/convolutional_test.py @@ -45,7 +45,7 @@ class Convolution1DTest(test.TestCase): kwargs=test_kwargs, input_shape=(num_samples, length, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_conv1d(self): kwargs = { 'filters': 2, @@ -117,7 +117,7 @@ class Conv2DTest(test.TestCase): kwargs=test_kwargs, input_shape=(num_samples, num_row, num_col, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_conv2d(self): kwargs = { 'filters': 2, @@ -192,7 +192,7 @@ class Conv2DTransposeTest(test.TestCase): kwargs=test_kwargs, input_shape=(num_samples, num_row, num_col, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_conv2dtranspose(self): kwargs = { 'filters': 2, @@ -258,7 +258,7 @@ class Conv3DTransposeTest(test.TestCase): kwargs=test_kwargs, input_shape=(num_samples, depth, num_row, num_col, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_conv3dtranspose(self): kwargs = { 'filters': 2, @@ -322,7 +322,7 @@ class SeparableConv1DTest(test.TestCase): kwargs=test_kwargs, input_shape=(num_samples, length, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_separable_conv1d(self): kwargs = { 'filters': 2, @@ -398,7 +398,7 @@ class SeparableConv2DTest(test.TestCase): kwargs=test_kwargs, input_shape=(num_samples, num_row, num_col, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_separable_conv2d(self): kwargs = { 'filters': 2, @@ -477,7 +477,7 @@ class Conv3DTest(test.TestCase): kwargs=test_kwargs, input_shape=(num_samples, depth, num_row, num_col, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_conv3d(self): kwargs = { 'filters': 2, @@ -529,7 +529,7 @@ class Conv3DTest(test.TestCase): class ZeroPaddingTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_zero_padding_1d(self): num_samples = 2 input_dim = 2 @@ -581,7 +581,7 @@ class ZeroPaddingTest(test.TestCase): with self.assertRaises(ValueError): keras.layers.ZeroPadding1D(padding=None) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_zero_padding_2d(self): num_samples = 2 stack_size = 2 @@ -660,7 +660,7 @@ class ZeroPaddingTest(test.TestCase): with self.assertRaises(ValueError): keras.layers.ZeroPadding2D(padding=None) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_zero_padding_3d(self): num_samples = 2 stack_size = 2 @@ -702,13 +702,13 @@ class ZeroPaddingTest(test.TestCase): class UpSamplingTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_upsampling_1d(self): with self.test_session(use_gpu=True): testing_utils.layer_test( keras.layers.UpSampling1D, kwargs={'size': 2}, input_shape=(3, 5, 4)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_upsampling_2d(self): num_samples = 2 stack_size = 2 @@ -758,7 +758,7 @@ class UpSamplingTest(test.TestCase): np.testing.assert_allclose(np_output, expected_out) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_upsampling_3d(self): num_samples = 2 stack_size = 2 @@ -818,7 +818,7 @@ class UpSamplingTest(test.TestCase): class CroppingTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_cropping_1d(self): num_samples = 2 time_length = 4 @@ -837,7 +837,7 @@ class CroppingTest(test.TestCase): with self.assertRaises(ValueError): keras.layers.Cropping1D(cropping=None) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_cropping_2d(self): num_samples = 2 stack_size = 2 @@ -905,7 +905,7 @@ class CroppingTest(test.TestCase): with self.assertRaises(ValueError): keras.layers.Cropping2D(cropping=None) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_cropping_3d(self): num_samples = 2 stack_size = 2 diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index f60064ed6363d36731795d08bb42e75398628283..2bf6229ccba808360e73a333bdec3dac624d81ce 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -33,8 +33,8 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras import constraints from tensorflow.python.keras import initializers from tensorflow.python.keras import regularizers -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.utils import conv_utils from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import tf_utils diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py index ff8af976b99376b037af81ed81707332ccf9937e..226403c5927ed22394b708178679d1efa11dd790 100644 --- a/tensorflow/python/keras/layers/core_test.py +++ b/tensorflow/python/keras/layers/core_test.py @@ -51,7 +51,7 @@ class CoreLayersTest(test.TestCase): dropout = keras.layers.Dropout(0.5) self.assertEqual(True, dropout.supports_masking) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_spatial_dropout(self): testing_utils.layer_test( keras.layers.SpatialDropout1D, @@ -78,7 +78,7 @@ class CoreLayersTest(test.TestCase): kwargs={'rate': 0.5, 'data_format': 'channels_first'}, input_shape=(2, 3, 4, 4, 5)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_activation(self): # with string argument testing_utils.layer_test( @@ -92,7 +92,7 @@ class CoreLayersTest(test.TestCase): kwargs={'activation': keras.backend.relu}, input_shape=(3, 2)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_reshape(self): testing_utils.layer_test( keras.layers.Reshape, @@ -114,12 +114,12 @@ class CoreLayersTest(test.TestCase): kwargs={'target_shape': (-1, 1)}, input_shape=(None, None, 2)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_permute(self): testing_utils.layer_test( keras.layers.Permute, kwargs={'dims': (2, 1)}, input_shape=(3, 2, 4)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_flatten(self): testing_utils.layer_test( keras.layers.Flatten, kwargs={}, input_shape=(3, 2, 4)) @@ -134,7 +134,7 @@ class CoreLayersTest(test.TestCase): np.transpose(inputs, (0, 2, 3, 1)), (-1, 5 * 5 * 3)) self.assertAllClose(outputs, target_outputs) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_repeat_vector(self): testing_utils.layer_test( keras.layers.RepeatVector, kwargs={'n': 3}, input_shape=(3, 2)) @@ -173,7 +173,7 @@ class CoreLayersTest(test.TestCase): config = ld.get_config() ld = keras.layers.Lambda.from_config(config) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_dense(self): testing_utils.layer_test( keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 2)) diff --git a/tensorflow/python/keras/layers/cudnn_recurrent.py b/tensorflow/python/keras/layers/cudnn_recurrent.py index ad6594279d037c8dc0e1408955d2a2eebd51ce1d..cf2b0c476c7229a288f4b4f7b31de09388ade40f 100644 --- a/tensorflow/python/keras/layers/cudnn_recurrent.py +++ b/tensorflow/python/keras/layers/cudnn_recurrent.py @@ -25,7 +25,7 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras import constraints from tensorflow.python.keras import initializers from tensorflow.python.keras import regularizers -from tensorflow.python.keras.engine import InputSpec +from tensorflow.python.keras.engine.base_layer import InputSpec from tensorflow.python.keras.layers.recurrent import RNN from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_cudnn_rnn_ops diff --git a/tensorflow/python/keras/layers/cudnn_recurrent_test.py b/tensorflow/python/keras/layers/cudnn_recurrent_test.py index 9d186f8c586bd9f626e142a855be6d2cf00d7121..8fd970239f205031954c728474abdf10ea80e99e 100644 --- a/tensorflow/python/keras/layers/cudnn_recurrent_test.py +++ b/tensorflow/python/keras/layers/cudnn_recurrent_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os +import tempfile from absl.testing import parameterized import numpy as np @@ -30,7 +32,7 @@ from tensorflow.python.training.rmsprop import RMSPropOptimizer class CuDNNTest(test.TestCase, parameterized.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_cudnn_rnn_basics(self): if test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True): @@ -58,7 +60,7 @@ class CuDNNTest(test.TestCase, parameterized.TestCase): 'go_backwards': go_backwards}, input_shape=(num_samples, timesteps, input_size)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_trainability(self): if test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True): @@ -217,27 +219,14 @@ class CuDNNTest(test.TestCase, parameterized.TestCase): out5 = model.predict(np.ones((num_samples, timesteps))) self.assertNotEqual(out4.max(), out5.max()) - # TODO(psv): Add generic cross product helper function for parametrized tests. @parameterized.named_parameters( - ('cudnnlstm_to_lstm_unidirectional_impl_1', 'LSTM', False, False, 1), - ('cudnnlstm_to_lstm_bidirectional_impl_1', 'LSTM', False, True, 1), - ('lstm_to_cudnnlstm_unidirectional_impl_1', 'LSTM', True, False, 1), - ('lstm_to_cudnnlstm_bidirectional_impl_1', 'LSTM', True, True, 1), - ('cudnngru_to_gru_unidirectional_impl_1', 'GRU', False, False, 1), - ('cudnngru_to_gru_bidirectional_impl_1', 'GRU', False, True, 1), - ('gru_to_cudnngru_unidirectional_impl_1', 'GRU', True, False, 1), - ('gru_to_cudnngru_bidirectional_impl_1', 'GRU', True, True, 1), - ('cudnnlstm_to_lstm_unidirectional_impl_2', 'LSTM', False, False, 2), - ('cudnnlstm_to_lstm_bidirectional_impl_2', 'LSTM', False, True, 2), - ('lstm_to_cudnnlstm_unidirectional_impl_2', 'LSTM', True, False, 2), - ('lstm_to_cudnnlstm_bidirectional_impl_2', 'LSTM', True, True, 2), - ('cudnngru_to_gru_unidirectional_impl_2', 'GRU', False, False, 2), - ('cudnngru_to_gru_bidirectional_impl_2', 'GRU', False, True, 2), - ('gru_to_cudnngru_unidirectional_impl_2', 'GRU', True, False, 2), - ('gru_to_cudnngru_bidirectional_impl_2', 'GRU', True, True, 2), - ) + *testing_utils.generate_combinations_with_testcase_name( + rnn_type=['LSTM', 'GRU'], to_cudnn=[True, False], + bidirectional=[True, False], implementation=[1, 2], + model_nest_level=[1, 2], model_type=['seq', 'func'])) def test_load_weights_between_noncudnn_rnn(self, rnn_type, to_cudnn, - bidirectional, implementation): + bidirectional, implementation, + model_nest_level, model_type): if test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True): input_size = 10 @@ -261,14 +250,6 @@ class CuDNNTest(test.TestCase, parameterized.TestCase): cudnn_rnn_layer_class = keras.layers.CuDNNGRU rnn_layer_kwargs['reset_after'] = True - def convert_weights(source_layer, target_layer): - weights = source_layer.get_weights() - weights = keras.engine.saving.preprocess_weights_for_loading( - target_layer, weights) - target_layer.set_weights(weights) - - input_layer = keras.layers.InputLayer(input_shape) - layer = rnn_layer_class(units, **rnn_layer_kwargs) if bidirectional: layer = keras.layers.Bidirectional(layer) @@ -277,18 +258,96 @@ class CuDNNTest(test.TestCase, parameterized.TestCase): if bidirectional: cudnn_layer = keras.layers.Bidirectional(cudnn_layer) - model = keras.models.Sequential([input_layer, layer]) - cudnn_model = keras.models.Sequential([input_layer, cudnn_layer]) + model = self._make_nested_model(input_shape, layer, model_nest_level, + model_type) + cudnn_model = self._make_nested_model(input_shape, cudnn_layer, + model_nest_level, model_type) + + if to_cudnn: + self._convert_model_weights(model, cudnn_model) + else: + self._convert_model_weights(cudnn_model, model) + + self.assertAllClose(model.predict(inputs), cudnn_model.predict(inputs), + atol=1e-4) + + def _make_nested_model(self, input_shape, layer, level=1, model_type='func'): + # example: make_nested_seq_model((1,), Dense(10), level=2).summary() + def make_nested_seq_model(input_shape, layer, level=1): + model = layer + for i in range(1, level + 1): + layers = [keras.layers.InputLayer(input_shape), + model] if (i == 1) else [model] + model = keras.models.Sequential(layers) + return model + + # example: make_nested_func_model((1,), Dense(10), level=2).summary() + def make_nested_func_model(input_shape, layer, level=1): + model_input = keras.layers.Input(input_shape) + model = layer + for _ in range(level): + model = keras.models.Model(model_input, model(model_input)) + return model + + if model_type == 'func': + return make_nested_func_model(input_shape, layer, level) + elif model_type == 'seq': + return make_nested_seq_model(input_shape, layer, level) + + def _convert_model_weights(self, source_model, target_model): + _, fname = tempfile.mkstemp('.h5') + source_model.save_weights(fname) + target_model.load_weights(fname) + os.remove(fname) + + @parameterized.named_parameters( + *testing_utils.generate_combinations_with_testcase_name( + rnn_type=['LSTM', 'GRU'], to_cudnn=[True, False])) + def test_load_weights_between_noncudnn_rnn_time_distributed(self, rnn_type, + to_cudnn): + # Similar test as test_load_weights_between_noncudnn_rnn() but has different + # rank of input due to usage of TimeDistributed. Issue: #10356. + if test.is_gpu_available(cuda_only=True): + with self.test_session(use_gpu=True): + input_size = 10 + steps = 6 + timesteps = 6 + input_shape = (timesteps, steps, input_size) + units = 2 + num_samples = 32 + inputs = np.random.random((num_samples, timesteps, steps, input_size)) + + rnn_layer_kwargs = { + 'recurrent_activation': 'sigmoid', + # ensure biases are non-zero and properly converted + 'bias_initializer': 'random_uniform', + } + if rnn_type == 'LSTM': + rnn_layer_class = keras.layers.LSTM + cudnn_rnn_layer_class = keras.layers.CuDNNLSTM + else: + rnn_layer_class = keras.layers.GRU + cudnn_rnn_layer_class = keras.layers.CuDNNGRU + rnn_layer_kwargs['reset_after'] = True + + layer = rnn_layer_class(units, **rnn_layer_kwargs) + layer = keras.layers.TimeDistributed(layer) + + cudnn_layer = cudnn_rnn_layer_class(units) + cudnn_layer = keras.layers.TimeDistributed(cudnn_layer) + + model = self._make_nested_model(input_shape, layer) + cudnn_model = self._make_nested_model(input_shape, cudnn_layer) if to_cudnn: - convert_weights(layer, cudnn_layer) + self._convert_model_weights(model, cudnn_model) else: - convert_weights(cudnn_layer, layer) + self._convert_model_weights(cudnn_model, model) - self.assertAllClose( - model.predict(inputs), cudnn_model.predict(inputs), atol=1e-4) + self.assertAllClose(model.predict(inputs), cudnn_model.predict(inputs), + atol=1e-4) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_cudnnrnn_bidirectional(self): if test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True): diff --git a/tensorflow/python/keras/layers/embeddings.py b/tensorflow/python/keras/layers/embeddings.py index 25eeeee9529bcb52e608eeb9468c210eea8bd8be..910fff720f6312041a25922cf5c63dfa8f83ec76 100644 --- a/tensorflow/python/keras/layers/embeddings.py +++ b/tensorflow/python/keras/layers/embeddings.py @@ -22,7 +22,7 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras import constraints from tensorflow.python.keras import initializers from tensorflow.python.keras import regularizers -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops diff --git a/tensorflow/python/keras/layers/gru_test.py b/tensorflow/python/keras/layers/gru_test.py index 234434f7a0205c7dda80d308e4780cd761352d77..57f660b6d5a70b950918a3f6d75c87ecccf76f82 100644 --- a/tensorflow/python/keras/layers/gru_test.py +++ b/tensorflow/python/keras/layers/gru_test.py @@ -29,7 +29,7 @@ from tensorflow.python.training.rmsprop import RMSPropOptimizer class GRULayerTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_return_sequences_GRU(self): num_samples = 2 timesteps = 3 @@ -41,7 +41,7 @@ class GRULayerTest(test.TestCase): 'return_sequences': True}, input_shape=(num_samples, timesteps, embedding_dim)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_dynamic_behavior_GRU(self): num_samples = 2 timesteps = 3 @@ -55,7 +55,7 @@ class GRULayerTest(test.TestCase): y = np.random.random((num_samples, units)) model.train_on_batch(x, y) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_dropout_GRU(self): num_samples = 2 timesteps = 3 @@ -68,7 +68,7 @@ class GRULayerTest(test.TestCase): 'recurrent_dropout': 0.1}, input_shape=(num_samples, timesteps, embedding_dim)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_implementation_mode_GRU(self): num_samples = 2 timesteps = 3 diff --git a/tensorflow/python/keras/layers/local.py b/tensorflow/python/keras/layers/local.py index f222ea3083bad48094fbec7fe6750921f0233e35..0ebafe07cc45698200d0b1fa858a436c7a08820e 100644 --- a/tensorflow/python/keras/layers/local.py +++ b/tensorflow/python/keras/layers/local.py @@ -23,8 +23,8 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras import constraints from tensorflow.python.keras import initializers from tensorflow.python.keras import regularizers -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.utils import conv_utils from tensorflow.python.keras.utils import tf_utils from tensorflow.python.util.tf_export import tf_export @@ -140,9 +140,9 @@ class LocallyConnected1D(Layer): if input_dim is None: raise ValueError('Axis 2 of input should be fully-defined. ' 'Found shape:', input_shape) - output_length = conv_utils.conv_output_length( + self.output_length = conv_utils.conv_output_length( input_length, self.kernel_size[0], self.padding, self.strides[0]) - self.kernel_shape = (output_length, self.kernel_size[0] * input_dim, + self.kernel_shape = (self.output_length, self.kernel_size[0] * input_dim, self.filters) self.kernel = self.add_weight( shape=self.kernel_shape, @@ -152,7 +152,7 @@ class LocallyConnected1D(Layer): constraint=self.kernel_constraint) if self.use_bias: self.bias = self.add_weight( - shape=(output_length, self.filters), + shape=(self.output_length, self.filters), initializer=self.bias_initializer, name='bias', regularizer=self.bias_regularizer, @@ -182,12 +182,13 @@ class LocallyConnected1D(Layer): return (input_shape[0], length, self.filters) def call(self, inputs): - output = K.local_conv1d(inputs, self.kernel, self.kernel_size, - self.strides, self.data_format) + output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides, + (self.output_length,), self.data_format) + if self.use_bias: output = K.bias_add(output, self.bias, data_format=self.data_format) - if self.activation is not None: - output = self.activation(output) + + output = self.activation(output) return output def get_config(self): @@ -400,9 +401,8 @@ class LocallyConnected2D(Layer): return (input_shape[0], rows, cols, self.filters) def call(self, inputs): - output = K.local_conv2d(inputs, self.kernel, self.kernel_size, self.strides, - (self.output_row, self.output_col), - self.data_format) + output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides, + (self.output_row, self.output_col), self.data_format) if self.use_bias: output = K.bias_add(output, self.bias, data_format=self.data_format) diff --git a/tensorflow/python/keras/layers/local_test.py b/tensorflow/python/keras/layers/local_test.py index 8df3f6b7bd741ad0b698fe500f0ac72e73985421..9639e0251f5a56e4130b13c0185792fe11da2532 100644 --- a/tensorflow/python/keras/layers/local_test.py +++ b/tensorflow/python/keras/layers/local_test.py @@ -28,7 +28,7 @@ from tensorflow.python.platform import test class LocallyConnectedLayersTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_locallyconnected_1d(self): num_samples = 2 num_steps = 8 @@ -92,7 +92,7 @@ class LocallyConnectedLayersTest(test.TestCase): self.assertEqual(layer.kernel.constraint, k_constraint) self.assertEqual(layer.bias.constraint, b_constraint) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_locallyconnected_2d(self): num_samples = 8 filters = 3 @@ -118,7 +118,7 @@ class LocallyConnectedLayersTest(test.TestCase): }, input_shape=(num_samples, num_row, num_col, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_locallyconnected_2d_channels_first(self): num_samples = 8 filters = 3 diff --git a/tensorflow/python/keras/layers/lstm_test.py b/tensorflow/python/keras/layers/lstm_test.py index 87cb344bf82b73b6af9830a4428a5ba099135324..ae381f595565cf0d060320354cb32585c1067f72 100644 --- a/tensorflow/python/keras/layers/lstm_test.py +++ b/tensorflow/python/keras/layers/lstm_test.py @@ -29,7 +29,7 @@ from tensorflow.python.training.rmsprop import RMSPropOptimizer class LSTMLayerTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_return_sequences_LSTM(self): num_samples = 2 timesteps = 3 @@ -56,7 +56,7 @@ class LSTMLayerTest(test.TestCase): outputs = model.layers[-1].output self.assertEquals(outputs.get_shape().as_list(), [None, timesteps, units]) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_dynamic_behavior_LSTM(self): num_samples = 2 timesteps = 3 @@ -70,7 +70,7 @@ class LSTMLayerTest(test.TestCase): y = np.random.random((num_samples, units)) model.train_on_batch(x, y) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_dropout_LSTM(self): num_samples = 2 timesteps = 3 @@ -83,7 +83,7 @@ class LSTMLayerTest(test.TestCase): 'recurrent_dropout': 0.1}, input_shape=(num_samples, timesteps, embedding_dim)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_implementation_mode_LSTM(self): num_samples = 2 timesteps = 3 diff --git a/tensorflow/python/keras/layers/merge_test.py b/tensorflow/python/keras/layers/merge_test.py index 8a097cf7f57d06155f26e3099554e34a54186189..39bc98d039624d50788e1b7995dc5fba300a5276 100644 --- a/tensorflow/python/keras/layers/merge_test.py +++ b/tensorflow/python/keras/layers/merge_test.py @@ -28,7 +28,7 @@ from tensorflow.python.platform import test class MergeLayersTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_add(self): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) @@ -76,7 +76,7 @@ class MergeLayersTest(test.TestCase): with self.assertRaises(ValueError): keras.layers.add([i1]) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_multiply(self): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) @@ -92,7 +92,7 @@ class MergeLayersTest(test.TestCase): self.assertEqual(out.shape, (2, 4, 5)) self.assertAllClose(out, x1 * x2 * x3, atol=1e-4) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_average(self): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) @@ -106,7 +106,7 @@ class MergeLayersTest(test.TestCase): self.assertEqual(out.shape, (2, 4, 5)) self.assertAllClose(out, 0.5 * (x1 + x2), atol=1e-4) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_maximum(self): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) @@ -120,7 +120,7 @@ class MergeLayersTest(test.TestCase): self.assertEqual(out.shape, (2, 4, 5)) self.assertAllClose(out, np.maximum(x1, x2), atol=1e-4) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_minimum(self): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) @@ -134,7 +134,7 @@ class MergeLayersTest(test.TestCase): self.assertEqual(out.shape, (2, 4, 5)) self.assertAllClose(out, np.minimum(x1, x2), atol=1e-4) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_concatenate(self): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) @@ -169,7 +169,7 @@ class MergeLayersTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'called on a list'): keras.layers.concatenate([i1], axis=-1) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_dot(self): i1 = keras.layers.Input(shape=(4,)) i2 = keras.layers.Input(shape=(4,)) @@ -215,7 +215,7 @@ class MergeLayersTest(test.TestCase): dot = keras.layers.Dot(1) dot.compute_output_shape(1) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_subtract(self): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) diff --git a/tensorflow/python/keras/layers/noise.py b/tensorflow/python/keras/layers/noise.py index a895caa25b91702d92002f84fe44b5b5c3a8ca0c..cb7cee3ebc3ebd2413836b876f2aaf21985f1d9c 100644 --- a/tensorflow/python/keras/layers/noise.py +++ b/tensorflow/python/keras/layers/noise.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.python.keras import backend as K -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops diff --git a/tensorflow/python/keras/layers/noise_test.py b/tensorflow/python/keras/layers/noise_test.py index bde2185f03bd45c1c9fecbd6fe5544a17e9c04ef..aa2be62390b0dcf0656a533cba9bdbe9ceee09dd 100644 --- a/tensorflow/python/keras/layers/noise_test.py +++ b/tensorflow/python/keras/layers/noise_test.py @@ -40,7 +40,7 @@ class NoiseLayersTest(test.TestCase): kwargs={'rate': 0.5}, input_shape=(3, 2, 3)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_AlphaDropout(self): testing_utils.layer_test( keras.layers.AlphaDropout, diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py index 28cedec338246bbacd8b8cf40d83fc43ffdba0a2..d4c213eedd9eb3da0a3644540da29fa22a60f453 100644 --- a/tensorflow/python/keras/layers/normalization.py +++ b/tensorflow/python/keras/layers/normalization.py @@ -26,8 +26,8 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras import constraints from tensorflow.python.keras import initializers from tensorflow.python.keras import regularizers -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops diff --git a/tensorflow/python/keras/layers/pooling.py b/tensorflow/python/keras/layers/pooling.py index 10a82b285eff6f6b414e67441ceb88976ca2368f..912e8bd619db8b35a54853c0752382479567fd04 100644 --- a/tensorflow/python/keras/layers/pooling.py +++ b/tensorflow/python/keras/layers/pooling.py @@ -20,8 +20,8 @@ from __future__ import print_function from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import backend -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.utils import conv_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn diff --git a/tensorflow/python/keras/layers/pooling_test.py b/tensorflow/python/keras/layers/pooling_test.py index cbd58a22879975b7dbaab8290f59cee573b272cd..2cd9939e66ff869dac5058d2dd00d8d495e40f55 100644 --- a/tensorflow/python/keras/layers/pooling_test.py +++ b/tensorflow/python/keras/layers/pooling_test.py @@ -27,14 +27,14 @@ from tensorflow.python.platform import test class GlobalPoolingTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_globalpooling_1d(self): testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D, input_shape=(3, 4, 5)) testing_utils.layer_test( keras.layers.pooling.GlobalAveragePooling1D, input_shape=(3, 4, 5)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_globalpooling_2d(self): testing_utils.layer_test( keras.layers.pooling.GlobalMaxPooling2D, @@ -53,7 +53,7 @@ class GlobalPoolingTest(test.TestCase): kwargs={'data_format': 'channels_last'}, input_shape=(3, 5, 6, 4)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_globalpooling_3d(self): testing_utils.layer_test( keras.layers.pooling.GlobalMaxPooling3D, @@ -75,7 +75,7 @@ class GlobalPoolingTest(test.TestCase): class Pooling2DTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_maxpooling_2d(self): pool_size = (3, 3) for strides in [(1, 1), (2, 2)]: @@ -88,7 +88,7 @@ class Pooling2DTest(test.TestCase): }, input_shape=(3, 5, 6, 4)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_averagepooling_2d(self): testing_utils.layer_test( keras.layers.AveragePooling2D, @@ -122,7 +122,7 @@ class Pooling2DTest(test.TestCase): class Pooling3DTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_maxpooling_3d(self): pool_size = (3, 3, 3) testing_utils.layer_test( @@ -141,7 +141,7 @@ class Pooling3DTest(test.TestCase): }, input_shape=(3, 4, 11, 12, 10)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_averagepooling_3d(self): pool_size = (3, 3, 3) testing_utils.layer_test( @@ -163,7 +163,7 @@ class Pooling3DTest(test.TestCase): class Pooling1DTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_maxpooling_1d(self): for padding in ['valid', 'same']: for stride in [1, 2]: @@ -173,7 +173,7 @@ class Pooling1DTest(test.TestCase): 'padding': padding}, input_shape=(3, 5, 4)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_averagepooling_1d(self): for padding in ['valid', 'same']: for stride in [1, 2]: diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index 7e509fb45182653d938adfd679e204cc7ea1e900..32d25c5a650d3b66d944eee945cafa2d6f54d405 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -29,8 +29,8 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras import constraints from tensorflow.python.keras import initializers from tensorflow.python.keras import regularizers -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops diff --git a/tensorflow/python/keras/layers/serialization.py b/tensorflow/python/keras/layers/serialization.py index be306c0af765dd79bcc2b7651d97957c1cf80519..7c45e08b5c48084cc57569a4d1102a0a7c5b29e1 100644 --- a/tensorflow/python/keras/layers/serialization.py +++ b/tensorflow/python/keras/layers/serialization.py @@ -20,8 +20,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras.engine import Input -from tensorflow.python.keras.engine import InputLayer +from tensorflow.python.keras.engine.input_layer import Input +from tensorflow.python.keras.engine.input_layer import InputLayer from tensorflow.python.keras.layers.advanced_activations import * from tensorflow.python.keras.layers.convolutional import * from tensorflow.python.keras.layers.convolutional_recurrent import * diff --git a/tensorflow/python/keras/layers/simplernn_test.py b/tensorflow/python/keras/layers/simplernn_test.py index 3d24b0d5045d9c264f32adedaa0e91cdc5cbb0cf..18fefbe84f6f46f2043c6586ecbc85ea76c55ea0 100644 --- a/tensorflow/python/keras/layers/simplernn_test.py +++ b/tensorflow/python/keras/layers/simplernn_test.py @@ -29,7 +29,7 @@ from tensorflow.python.training.rmsprop import RMSPropOptimizer class SimpleRNNLayerTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_return_sequences_SimpleRNN(self): num_samples = 2 timesteps = 3 @@ -41,7 +41,7 @@ class SimpleRNNLayerTest(test.TestCase): 'return_sequences': True}, input_shape=(num_samples, timesteps, embedding_dim)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_dynamic_behavior_SimpleRNN(self): num_samples = 2 timesteps = 3 @@ -55,7 +55,7 @@ class SimpleRNNLayerTest(test.TestCase): y = np.random.random((num_samples, units)) model.train_on_batch(x, y) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_dropout_SimpleRNN(self): num_samples = 2 timesteps = 3 @@ -68,7 +68,7 @@ class SimpleRNNLayerTest(test.TestCase): 'recurrent_dropout': 0.1}, input_shape=(num_samples, timesteps, embedding_dim)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_implementation_mode_SimpleRNN(self): num_samples = 2 timesteps = 3 diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py index 7759561ef94c4a81552ef7b40ea71e49bbb743ae..e61acf8e771eb8de1c466ffa5e1c4c7f543f77ef 100644 --- a/tensorflow/python/keras/layers/wrappers.py +++ b/tensorflow/python/keras/layers/wrappers.py @@ -23,8 +23,8 @@ import copy from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import backend as K -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.layers.recurrent import _standardize_args from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import tf_utils @@ -45,7 +45,9 @@ class Wrapper(Layer): """ def __init__(self, layer, **kwargs): + assert isinstance(layer, Layer) self.layer = layer + self._track_checkpointable(layer, name='layer') # Tracks mapping of Wrapper inputs to inner layer inputs. Useful when # the inner layer has update ops that depend on its inputs (as opposed # to the inputs to the Wrapper layer). @@ -154,9 +156,16 @@ class TimeDistributed(Wrapper): Arguments: layer: a layer instance. + + Raises: + ValueError: If not initialized with a `Layer` instance. """ def __init__(self, layer, **kwargs): + if not isinstance(layer, Layer): + raise ValueError( + 'Please initialize `TimeDistributed` layer with a ' + '`Layer` instance. You passed: {input}'.format(input=layer)) super(TimeDistributed, self).__init__(layer, **kwargs) self.supports_masking = True @@ -166,7 +175,10 @@ class TimeDistributed(Wrapper): self.input_spec = InputSpec(shape=input_shape) child_input_shape = [input_shape[0]] + input_shape[2:] if not self.layer.built: - self.layer.build(child_input_shape) + # The base layer class calls a conversion function on the input shape to + # convert it to a TensorShape. The conversion function requires a + # tuple which is why we cast the shape. + self.layer.build(tuple(child_input_shape)) self.layer.built = True super(TimeDistributed, self).build() self.built = True @@ -249,7 +261,8 @@ class Bidirectional(Wrapper): they will be returned as a list. Raises: - ValueError: In case of invalid `merge_mode` argument. + ValueError: If not initialized with a `Layer` instance or + In case of invalid `merge_mode` argument. Examples: @@ -265,6 +278,10 @@ class Bidirectional(Wrapper): """ def __init__(self, layer, merge_mode='concat', weights=None, **kwargs): + if not isinstance(layer, Layer): + raise ValueError( + 'Please initialize `Bidirectional` layer with a ' + '`Layer` instance. You passed: {input}'.format(input=layer)) if merge_mode not in ['sum', 'mul', 'ave', 'concat', None]: raise ValueError('Invalid merge mode. ' 'Merge mode should be one of ' diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py index 5eab6aba8a5f9a7e70f55685a9cd9ae6e0cf024d..c8f0d216e6f7a3bb715286bd6e7975a5dc1ac1cc 100644 --- a/tensorflow/python/keras/layers/wrappers_test.py +++ b/tensorflow/python/keras/layers/wrappers_test.py @@ -23,8 +23,10 @@ import copy import numpy as np from tensorflow.python import keras +from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import util as checkpointable_util from tensorflow.python.training.rmsprop import RMSPropOptimizer @@ -69,7 +71,7 @@ class _RNNCellWithConstants(keras.layers.Layer): class TimeDistributedTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_timedistributed_dense(self): model = keras.models.Sequential() model.add( @@ -85,6 +87,10 @@ class TimeDistributedTest(test.TestCase): # test config model.get_config() + checkpointed_objects = set(checkpointable_util.list_objects(model)) + for v in model.variables: + self.assertIn(v, checkpointed_objects) + def test_timedistributed_static_batch_size(self): model = keras.models.Sequential() model.add( @@ -97,6 +103,13 @@ class TimeDistributedTest(test.TestCase): epochs=1, batch_size=10) + def test_timedistributed_invalid_init(self): + x = constant_op.constant(np.zeros((1, 1)).astype('float32')) + with self.assertRaisesRegexp( + ValueError, + 'Please initialize `TimeDistributed` layer with a `Layer` instance.'): + keras.layers.TimeDistributed(x) + def test_timedistributed_conv2d(self): with self.test_session(): model = keras.models.Sequential() @@ -220,6 +233,13 @@ class BidirectionalTest(test.TestCase): model = keras.models.model_from_json(model.to_json()) model.summary() + def test_bidirectional_invalid_init(self): + x = constant_op.constant(np.zeros((1, 1)).astype('float32')) + with self.assertRaisesRegexp( + ValueError, + 'Please initialize `Bidirectional` layer with a `Layer` instance.'): + keras.layers.Bidirectional(x) + def test_bidirectional_weight_loading(self): rnn = keras.layers.SimpleRNN samples = 2 @@ -424,6 +444,42 @@ class BidirectionalTest(test.TestCase): layer.trainable = True assert len(layer.trainable_weights) == 6 + def test_Bidirectional_updates(self): + with self.test_session(): + x = keras.layers.Input(shape=(3, 2)) + x_reachable_update = x * x + layer = keras.layers.Bidirectional(keras.layers.SimpleRNN(3)) + _ = layer(x) + assert not layer.updates + assert not layer.get_updates_for(None) + assert not layer.get_updates_for(x) + layer.forward_layer.add_update(x_reachable_update, inputs=x) + layer.forward_layer.add_update(1, inputs=None) + layer.backward_layer.add_update(x_reachable_update, inputs=x) + layer.backward_layer.add_update(1, inputs=None) + assert len(layer.updates) == 4 + assert len(layer.get_updates_for(None)) == 2 + assert len(layer.get_updates_for(x)) == 2 + + def test_Bidirectional_losses(self): + with self.test_session(): + x = keras.layers.Input(shape=(3, 2)) + x_reachable_loss = x * x + layer = keras.layers.Bidirectional( + keras.layers.SimpleRNN( + 3, kernel_regularizer='l1', bias_regularizer='l1')) + _ = layer(x) + assert len(layer.losses) == 4 + assert len(layer.get_losses_for(None)) == 4 + assert not layer.get_losses_for(x) + layer.forward_layer.add_loss(x_reachable_loss, inputs=x) + layer.forward_layer.add_loss(1, inputs=None) + layer.backward_layer.add_loss(x_reachable_loss, inputs=x) + layer.backward_layer.add_loss(1, inputs=None) + assert len(layer.losses) == 8 + assert len(layer.get_losses_for(None)) == 6 + assert len(layer.get_losses_for(x)) == 2 + def test_Bidirectional_with_constants(self): with self.test_session(): # Test basic case. diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py index 8fb957da439dd490bc3378df96f611733335c809..b7e16a41ddaa4fc1f34ffbc0be7150cb10c7a10f 100644 --- a/tensorflow/python/keras/model_subclassing_test.py +++ b/tensorflow/python/keras/model_subclassing_test.py @@ -173,7 +173,7 @@ def get_nested_model_3(input_dim, num_classes): class ModelSubclassingTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_single_io_workflow_with_np_arrays(self): num_classes = 2 num_samples = 100 @@ -192,7 +192,7 @@ class ModelSubclassingTest(test.TestCase): model.fit(x, y, epochs=2, batch_size=32, verbose=0) _ = model.evaluate(x, y, verbose=0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_multi_io_workflow_with_np_arrays(self): num_classes = (2, 3) num_samples = 1000 @@ -251,7 +251,7 @@ class ModelSubclassingTest(test.TestCase): model.fit([x1, x2], [y1, y2], epochs=2, steps_per_epoch=10, verbose=0) _ = model.evaluate(steps=10, verbose=0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_single_io_workflow_with_dataset_iterators(self): num_classes = 2 num_samples = 10 @@ -325,7 +325,7 @@ class ModelSubclassingTest(test.TestCase): self.assertEqual(len(model.inputs), 2) self.assertEqual(len(model.outputs), 2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_updates(self): # test that updates get run during training num_samples = 100 @@ -352,7 +352,74 @@ class ModelSubclassingTest(test.TestCase): y_new = model.predict(x) self.assertGreater(np.sum(np.abs(y_ref - y_new)), 0.1) - @test_util.run_in_graph_and_eager_modes() + def test_updates_and_losses_for_nested_models_in_subclassed_model(self): + + # Case 1: deferred-build sequential nested in subclass. + class TestModel1(keras.Model): + + def __init__(self): + super(TestModel1, self).__init__() + self.fc = keras.layers.Dense(10, input_shape=(784,), + activity_regularizer='l1') + self.bn = keras.Sequential([keras.layers.BatchNormalization(axis=1)]) + + def call(self, x): + return self.bn(self.fc(x)) + + with self.test_session(): + model = TestModel1() + + x = array_ops.ones(shape=[100, 784], dtype='float32') + model(x) + self.assertEqual(len(model.get_updates_for(x)), 2) + self.assertEqual(len(model.get_losses_for(x)), 1) + + # Case 2: placeholder-sequential nested in subclass. + class TestModel2(keras.Model): + + def __init__(self): + super(TestModel2, self).__init__() + self.fc = keras.layers.Dense(10, input_shape=(784,), + activity_regularizer='l1') + self.bn = keras.Sequential( + [keras.layers.BatchNormalization(axis=1, input_shape=(10,))]) + + def call(self, x): + return self.bn(self.fc(x)) + + with self.test_session(): + model = TestModel2() + + x = array_ops.ones(shape=[100, 784], dtype='float32') + model(x) + self.assertEqual(len(model.get_updates_for(x)), 2) + self.assertEqual(len(model.get_losses_for(x)), 1) + + # Case 3: functional-API model nested in subclass. + inputs = keras.Input((10,)) + outputs = keras.layers.BatchNormalization(axis=1)(inputs) + bn = keras.Model(inputs, outputs) + + class TestModel3(keras.Model): + + def __init__(self): + super(TestModel3, self).__init__() + self.fc = keras.layers.Dense(10, input_shape=(784,), + activity_regularizer='l1') + self.bn = bn + + def call(self, x): + return self.bn(self.fc(x)) + + with self.test_session(): + model = TestModel3() + + x = array_ops.ones(shape=[100, 784], dtype='float32') + model(x) + self.assertEqual(len(model.get_updates_for(x)), 2) + self.assertEqual(len(model.get_losses_for(x)), 1) + + @test_util.run_in_graph_and_eager_modes def test_training_and_inference_behavior(self): # test that dropout is applied in training and not inference @@ -380,7 +447,7 @@ class ModelSubclassingTest(test.TestCase): loss = model.train_on_batch(x, y) self.assertGreater(loss, 0.1) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_training_methods(self): # test fit, train_on_batch # on different input types: list, dict @@ -433,14 +500,14 @@ class ModelSubclassingTest(test.TestCase): model = MultiIOTestModel(num_classes=num_classes, use_bn=True) model.predict_on_batch([x1, x2]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_trainable_mutation(self): # test that you can change `trainable` on a model or layer, and that # it freezes the model state during training # TODO(fchollet): add test after we unify BN behavior in eager and symbolic. pass - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_saving(self): num_classes = (2, 3) @@ -482,7 +549,7 @@ class ModelSubclassingTest(test.TestCase): self.assertAllClose(y_ref_1, y1, atol=1e-5) self.assertAllClose(y_ref_2, y2, atol=1e-5) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_summary(self): class ToString(object): @@ -508,7 +575,7 @@ class ModelSubclassingTest(test.TestCase): model.summary(print_fn=print_fn) self.assertTrue('Trainable params: 587' in print_fn.contents) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_subclass_nested_in_subclass(self): num_classes = 2 num_samples = 100 @@ -531,7 +598,7 @@ class ModelSubclassingTest(test.TestCase): self.assertEqual(len(model.trainable_weights), 6 + len(model.test_net.trainable_weights)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_graph_nested_in_subclass(self): num_classes = 2 num_samples = 100 @@ -554,7 +621,7 @@ class ModelSubclassingTest(test.TestCase): self.assertEqual(len(model.trainable_weights), 6 + len(model.test_net.trainable_weights)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_subclass_nested_in_graph(self): num_classes = 2 num_samples = 100 @@ -576,7 +643,7 @@ class ModelSubclassingTest(test.TestCase): len(model.non_trainable_weights), 4) self.assertEqual(len(model.trainable_weights), 12) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_support_for_manual_training_arg(self): # In most cases, the `training` argument is left unspecified, in which # case it defaults to value corresponding to the Model method being used @@ -685,7 +752,7 @@ class CustomCallModel(keras.Model): class CustomCallSignatureTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_no_inputs_in_signature(self): model = CustomCallModel() first = array_ops.ones([2, 3]) @@ -699,7 +766,7 @@ class CustomCallSignatureTests(test.TestCase): output = model(first, second=second, training=False) self.assertAllClose(expected_output, self.evaluate(output)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_inputs_in_signature(self): class HasInputsAndOtherPositional(keras.Model): @@ -716,7 +783,7 @@ class CustomCallSignatureTests(test.TestCase): x1, x2 = keras.Input((1, 1)), keras.Input((1, 1)) model(x1, x2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_kwargs_in_signature(self): class HasKwargs(keras.Model): @@ -730,7 +797,7 @@ class CustomCallSignatureTests(test.TestCase): if not context.executing_eagerly(): six.assertCountEqual(self, [arg], model.inputs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_args_in_signature(self): class HasArgs(keras.Model): diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py index e6e45902a8f117e5765249da18afa7cc35aa6b16..ad3819e6e730b48e294b340d39fddeb6d7f2d6bf 100644 --- a/tensorflow/python/keras/models_test.py +++ b/tensorflow/python/keras/models_test.py @@ -129,7 +129,7 @@ class TestModelCloning(test.TestCase): class CheckpointingTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_optimizer_dependency(self): model = keras.models.Sequential() model.add(keras.layers.Dense(1, input_shape=(4,))) diff --git a/tensorflow/python/keras/optimizers.py b/tensorflow/python/keras/optimizers.py index f58aeaea1acae2717f00a0323b5ff297a8cc8b46..b02cafcf61fea5515d9139371fc41548ff3b87e7 100644 --- a/tensorflow/python/keras/optimizers.py +++ b/tensorflow/python/keras/optimizers.py @@ -19,57 +19,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy - import six from six.moves import zip # pylint: disable=redefined-builtin -from tensorflow.python.framework import dtypes as dtypes_module -from tensorflow.python.framework import ops from tensorflow.python.keras import backend as K from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras.utils.generic_utils import serialize_keras_object -from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import clip_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import optimizer as tf_optimizer_module from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import tracking as checkpointable from tensorflow.python.util.tf_export import tf_export -def clip_norm(g, c, n): - """Clip a tensor by norm. - - Arguments: - g: gradient tensor to clip. - c: clipping threshold. - n: norm of gradient tensor. - - Returns: - Clipped gradient tensor. - """ - if c > 0: - condition = n >= c - then_expression = lambda: math_ops.scalar_mul(c / n, g) - else_expression = lambda: g - - # saving the shape to avoid converting sparse tensor to dense - if isinstance(g, ops.Tensor): - g_shape = copy.copy(g.get_shape()) - elif isinstance(g, ops.IndexedSlices): - g_shape = copy.copy(g.dense_shape) - if condition.dtype != dtypes_module.bool: - condition = math_ops.cast(condition, 'bool') - g = control_flow_ops.cond(condition, then_expression, else_expression) - if isinstance(g, ops.Tensor): - g.set_shape(g_shape) - elif isinstance(g, ops.IndexedSlices): - g._dense_shape = g_shape # pylint: disable=protected-access - return g - - @tf_export('keras.optimizers.Optimizer') class Optimizer(object): """Abstract optimizer base class. @@ -91,6 +56,9 @@ class Optimizer(object): if k not in allowed_kwargs: raise TypeError('Unexpected keyword argument ' 'passed to optimizer: ' + str(k)) + # checks that clipnorm >= 0 and clipvalue >= 0 + if kwargs[k] < 0: + raise ValueError('Expected {} >= 0, received: {}'.format(k, kwargs[k])) self.__dict__.update(kwargs) self.updates = [] self.weights = [] @@ -119,12 +87,13 @@ class Optimizer(object): 'gradient defined (i.e. are differentiable). ' 'Common ops without gradient: ' 'K.argmax, K.round, K.eval.') - if hasattr(self, 'clipnorm') and self.clipnorm > 0: - norm = K.sqrt( - sum([math_ops.reduce_sum(math_ops.square(g)) for g in grads])) - grads = [clip_norm(g, self.clipnorm, norm) for g in grads] - if hasattr(self, 'clipvalue') and self.clipvalue > 0: - grads = [K.clip(g, -self.clipvalue, self.clipvalue) for g in grads] + if hasattr(self, 'clipnorm'): + grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads] + if hasattr(self, 'clipvalue'): + grads = [ + clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue) + for g in grads + ] return grads def set_weights(self, weights): diff --git a/tensorflow/python/keras/optimizers_test.py b/tensorflow/python/keras/optimizers_test.py index 92b0cf326158adb1c6124384571a075196dbd3cc..55fc3fdcf47b4e5589e2253fffdc97d33f5b481b 100644 --- a/tensorflow/python/keras/optimizers_test.py +++ b/tensorflow/python/keras/optimizers_test.py @@ -145,6 +145,12 @@ class KerasOptimizersTest(test.TestCase): with self.assertRaises(NotImplementedError): optimizer.from_config(None) + def test_negative_clipvalue_or_clipnorm(self): + with self.assertRaises(ValueError): + _ = keras.optimizers.SGD(lr=0.01, clipvalue=-0.5) + with self.assertRaises(ValueError): + _ = keras.optimizers.Adam(clipnorm=-2.0) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py index e7cb45d5e110dcb749ae2b1b86dd8dd5b8ded4ef..17aba7d86c236d9bb30d3a3376b3aac40b69e77d 100644 --- a/tensorflow/python/keras/testing_utils.py +++ b/tensorflow/python/keras/testing_utils.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from collections import OrderedDict import numpy as np from tensorflow.python import keras @@ -183,3 +184,76 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None, # for further checks in the caller function return actual_output + + +def _combine_named_parameters(**kwargs): + """Generate combinations based on its keyword arguments. + + Two sets of returned combinations can be concatenated using +. Their product + can be computed using `times()`. + + Args: + **kwargs: keyword arguments of form `option=[possibilities, ...]` + or `option=the_only_possibility`. + + Returns: + a list of dictionaries for each combination. Keys in the dictionaries are + the keyword argument names. Each key has one value - one of the + corresponding keyword argument values. + """ + if not kwargs: + return [OrderedDict()] + + sort_by_key = lambda k: k[0][0] + kwargs = OrderedDict(sorted(kwargs.items(), key=sort_by_key)) + first = list(kwargs.items())[0] + + rest = dict(list(kwargs.items())[1:]) + rest_combined = _combine_named_parameters(**rest) + + key = first[0] + values = first[1] + if not isinstance(values, list): + values = [values] + + combinations = [ + OrderedDict(sorted(list(combined.items()) + [(key, v)], key=sort_by_key)) + for v in values + for combined in rest_combined + ] + return combinations + + +def generate_combinations_with_testcase_name(**kwargs): + """Generate combinations based on its keyword arguments using combine(). + + This function calls combine() and appends a testcase name to the list of + dictionaries returned. The 'testcase_name' key is a required for named + parameterized tests. + + Args: + **kwargs: keyword arguments of form `option=[possibilities, ...]` + or `option=the_only_possibility`. + + Returns: + a list of dictionaries for each combination. Keys in the dictionaries are + the keyword argument names. Each key has one value - one of the + corresponding keyword argument values. + """ + combinations = _combine_named_parameters(**kwargs) + named_combinations = [] + for combination in combinations: + assert isinstance(combination, OrderedDict) + name = ''.join([ + '_{}_{}'.format( + ''.join(filter(str.isalnum, key)), + ''.join(filter(str.isalnum, str(value)))) + for key, value in combination.items() + ]) + named_combinations.append( + OrderedDict( + list(combination.items()) + [('testcase_name', + '_test{}'.format(name))])) + + return named_combinations + diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py index 88daff0461593f6270f3be8c06a277c7e6751286..1f28c59ea41a96461a7faba2c41f5e65e6af0180 100644 --- a/tensorflow/python/keras/utils/layer_utils.py +++ b/tensorflow/python/keras/utils/layer_utils.py @@ -26,6 +26,47 @@ from tensorflow.python.keras.utils.conv_utils import convert_kernel from tensorflow.python.util.tf_export import tf_export +def get_source_inputs(tensor, layer=None, node_index=None): + """Returns the list of input tensors necessary to compute `tensor`. + + Output will always be a list of tensors + (potentially with 1 element). + + Arguments: + tensor: The tensor to start from. + layer: Origin layer of the tensor. Will be + determined via tensor._keras_history if not provided. + node_index: Origin node index of the tensor. + + Returns: + List of input tensors. + """ + if not hasattr(tensor, '_keras_history'): + return tensor + + if layer is None or node_index: + layer, node_index, _ = tensor._keras_history + if not layer._inbound_nodes: + return [tensor] + else: + node = layer._inbound_nodes[node_index] + if not node.inbound_layers: + # Reached an Input layer, stop recursion. + return node.input_tensors + else: + source_tensors = [] + for i in range(len(node.inbound_layers)): + x = node.input_tensors[i] + layer = node.inbound_layers[i] + node_index = node.node_indices[i] + previous_sources = get_source_inputs(x, layer, node_index) + # Avoid input redundancy. + for x in previous_sources: + if x not in source_tensors: + source_tensors.append(x) + return source_tensors + + def count_params(weights): """Count the total number of scalars composing the weights. diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 5d29c2e5f86bd3c4997cc3f18f4cb760dc87d63b..8a6614c8371744351b352243476ab1877b84b637 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -893,6 +893,7 @@ tf_py_test( "//third_party/py/numpy", "//tensorflow/python:client_testlib", "//tensorflow/python:framework", + "//tensorflow/python:sparse_grad", "//tensorflow/python:sparse_ops", ], ) @@ -3087,3 +3088,22 @@ tf_py_test( data = [":invalid_op.so"], tags = ["no_pip"], ) + +tf_py_test( + name = "cond_v2_test", + size = "small", + srcs = ["cond_v2_test.py"], + additional_deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:cond_v2", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:training", + ], + grpc_enabled = True, +) diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index 08bf2d9c644bcde2a80e6138557dae6e19383dfd..40567571e6d259eff3f013c67d1d1f9504fcb9e4 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -1006,7 +1006,7 @@ class SliceAssignTest(test_util.TensorFlowTestCase): class ShapeSizeRankTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDenseShape(self): t_value = [[0, 42], [24, 0]] self.assertAllEqual((2, 2), self.evaluate(array_ops.shape(t_value))) @@ -1018,7 +1018,7 @@ class ShapeSizeRankTest(test_util.TensorFlowTestCase): self.assertEqual(4, self.evaluate(array_ops.size(t))) self.assertEqual(2, self.evaluate(array_ops.rank(t))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSparseShape(self): sp_value = sparse_tensor.SparseTensorValue( indices=((0, 1), (1, 0)), values=(42, 24), dense_shape=(2, 2)) @@ -1031,7 +1031,7 @@ class ShapeSizeRankTest(test_util.TensorFlowTestCase): self.assertEqual(4, self.evaluate(array_ops.size(sp))) self.assertEqual(2, self.evaluate(array_ops.rank(sp))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSizeDtype(self): tensor = [1] self.assertEqual(dtypes.int32, self.evaluate(array_ops.size(tensor)).dtype) @@ -1123,7 +1123,7 @@ class SequenceMaskTest(test_util.TensorFlowTestCase): class ConcatSliceResourceTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConcatSlice(self): r1 = test_ops.stub_resource_handle_op(container="a", shared_name="b") r2 = test_ops.stub_resource_handle_op(container="a", shared_name="c") diff --git a/tensorflow/python/kernel_tests/atrous_convolution_test.py b/tensorflow/python/kernel_tests/atrous_convolution_test.py index 0ef08581c9f931b991ef0c1218dc503345e248c2..b98e5fd3866cde007c6c00ae0cf04b1f1c46c6f2 100644 --- a/tensorflow/python/kernel_tests/atrous_convolution_test.py +++ b/tensorflow/python/kernel_tests/atrous_convolution_test.py @@ -124,7 +124,7 @@ class AtrousConvolutionTest(test.TestCase): x, w, "VALID", dilation_rate=[2, 2], data_format="NCHW") self.assertEqual(y.shape.as_list(), [1, 20, None, None]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAtrousConvolution2D(self): with self._delay_checks() as add_check: for padding in ["SAME", "VALID"]: @@ -139,7 +139,7 @@ class AtrousConvolutionTest(test.TestCase): dilation_rate=dilation_rate, ) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAtrousConvolution3D(self): with self._delay_checks() as add_check: for padding in ["SAME", "VALID"]: @@ -158,7 +158,7 @@ class AtrousConvolutionTest(test.TestCase): dilation_rate=dilation_rate, ) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAtrousConvolution1D(self): with self._delay_checks() as add_check: for padding in ["SAME", "VALID"]: @@ -173,7 +173,7 @@ class AtrousConvolutionTest(test.TestCase): dilation_rate=[rate], ) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAtrousConvolutionNC(self): if test.is_gpu_available(cuda_only=True): # "NCW" and "NCHW" formats are currently supported only on CUDA. @@ -197,7 +197,7 @@ class AtrousConvolutionTest(test.TestCase): data_format="NCHW", ) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAtrousSequence(self): """Tests optimization of sequence of atrous convolutions. diff --git a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py index 92cd53a031e73d4ff4ac50c2465f32a2c20545a7..4e31b1ea2a796a2e83696d278cf1b4784d177150 100644 --- a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py +++ b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py @@ -910,7 +910,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): feature_1_values = [11, 27] # Example 1: tree 0: 1.14, tree 1: 5.0, tree 2: 5.0 = > - # logit = 0.1*5.0+0.2*5.0+1*5 + # logit = 0.1*1.14+0.2*5.0+1*5 # Example 2: tree 0: 1.14, tree 1: 7.0, tree 2: -7 = > # logit= 0.1*1.14+0.2*7.0-1*7.0 expected_logits = [[6.114], [-5.486]] @@ -925,5 +925,147 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): self.assertAllClose(expected_logits, logits) +class FeatureContribsOpsTest(test_util.TensorFlowTestCase): + """Tests feature contribs ops for model understanding.""" + + def testContribsMultipleTree(self): + """Tests that the contribs work when we have multiple trees.""" + with self.test_session() as session: + tree_ensemble_config = boosted_trees_pb2.TreeEnsemble() + text_format.Merge( + """ + trees { + nodes { + bucketized_split { + feature_id: 2 + threshold: 28 + left_id: 1 + right_id: 2 + } + metadata { + gain: 7.62 + original_leaf: {scalar: 2.1} + } + } + nodes { + leaf { + scalar: 1.14 + } + } + nodes { + leaf { + scalar: 8.79 + } + } + } + trees { + nodes { + bucketized_split { + feature_id: 2 + threshold: 26 + left_id: 1 + right_id: 2 + } + } + nodes { + bucketized_split { + feature_id: 0 + threshold: 50 + left_id: 3 + right_id: 4 + } + metadata { + original_leaf: {scalar: 5.5} + } + } + nodes { + leaf { + scalar: 7.0 + } + } + nodes { + leaf { + scalar: 5.0 + } + } + nodes { + leaf { + scalar: 6.0 + } + } + } + trees { + nodes { + bucketized_split { + feature_id: 0 + threshold: 34 + left_id: 1 + right_id: 2 + } + } + nodes { + leaf { + scalar: -7.0 + } + } + nodes { + leaf { + scalar: 5.0 + } + } + } + tree_weights: 0.1 + tree_weights: 0.2 + tree_weights: 1.0 + tree_metadata: { + num_layers_grown: 1} + tree_metadata: { + num_layers_grown: 2} + tree_metadata: { + num_layers_grown: 1} + """, tree_ensemble_config) + + tree_ensemble = boosted_trees_ops.TreeEnsemble( + 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString()) + tree_ensemble_handle = tree_ensemble.resource_handle + resources.initialize_resources(resources.shared_resources()).run() + + feature_0_values = [36, 32] + feature_1_values = [13, -29] # Unused. Feature is not in above ensemble. + feature_2_values = [11, 27] + + # Expected logits are computed by traversing the logit path and + # subtracting child logits from parent logits. + bias = 2.1 * 0.1 # Root node of tree_0. + expected_feature_ids = ((2, 2, 0, 0), (2, 2, 0)) + # example_0 : (bias, 0.1 * 1.14, 0.2 * 5.5 + .114, 0.2 * 5. + .114, + # 1.0 * 5.0 + 0.2 * 5. + .114) + # example_1 : (bias, 0.1 * 1.14, 0.2 * 7 + .114, + # 1.0 * -7. + 0.2 * 7 + .114) + expected_logits_paths = ((bias, 0.114, 1.214, 1.114, 6.114), + (bias, 0.114, 1.514, -5.486)) + + bucketized_features = [ + feature_0_values, feature_1_values, feature_2_values + ] + + debug_op = boosted_trees_ops.example_debug_outputs( + tree_ensemble_handle, + bucketized_features=bucketized_features, + logits_dimension=1) + + serialized_examples_debug_outputs = session.run(debug_op) + feature_ids = [] + logits_paths = [] + for example in serialized_examples_debug_outputs: + example_debug_outputs = boosted_trees_pb2.DebugOutput() + example_debug_outputs.ParseFromString(example) + feature_ids.append(example_debug_outputs.feature_ids) + logits_paths.append(example_debug_outputs.logits_path) + + self.assertAllClose(feature_ids, expected_feature_ids) + self.assertAllClose(logits_paths, expected_logits_paths) + + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py index 7ef841c96b5cec9c7ae56c631896231ed663b8be..bda6ca5ca91ab1f55c4586f604a116a9b3fed874 100644 --- a/tensorflow/python/kernel_tests/check_ops_test.py +++ b/tensorflow/python/kernel_tests/check_ops_test.py @@ -34,45 +34,45 @@ from tensorflow.python.platform import test class AssertProperIterableTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_single_tensor_raises(self): tensor = constant_op.constant(1) with self.assertRaisesRegexp(TypeError, "proper"): check_ops.assert_proper_iterable(tensor) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_single_sparse_tensor_raises(self): ten = sparse_tensor.SparseTensor( indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]) with self.assertRaisesRegexp(TypeError, "proper"): check_ops.assert_proper_iterable(ten) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_single_ndarray_raises(self): array = np.array([1, 2, 3]) with self.assertRaisesRegexp(TypeError, "proper"): check_ops.assert_proper_iterable(array) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_single_string_raises(self): mystr = "hello" with self.assertRaisesRegexp(TypeError, "proper"): check_ops.assert_proper_iterable(mystr) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_non_iterable_object_raises(self): non_iterable = 1234 with self.assertRaisesRegexp(TypeError, "to be iterable"): check_ops.assert_proper_iterable(non_iterable) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_list_does_not_raise(self): list_of_stuff = [ constant_op.constant([11, 22]), constant_op.constant([1, 2]) ] check_ops.assert_proper_iterable(list_of_stuff) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_generator_does_not_raise(self): generator_of_stuff = (constant_op.constant([11, 22]), constant_op.constant( [1, 2])) @@ -81,14 +81,14 @@ class AssertProperIterableTest(test.TestCase): class AssertEqualTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_equal(self): small = constant_op.constant([1, 2], name="small") with ops.control_dependencies([check_ops.assert_equal(small, small)]): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_scalar_comparison(self): const_true = constant_op.constant(True, name="true") const_false = constant_op.constant(False, name="false") @@ -101,7 +101,7 @@ class AssertEqualTest(test.TestCase): x = check_ops.assert_equal(small, small) assert x is None - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_greater(self): # Static check static_small = constant_op.constant([1, 2], name="small") @@ -179,7 +179,7 @@ First 2 elements of y: check_ops.assert_equal(big, small, message="big does not equal small", summarize=2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_less(self): # Static check static_small = constant_op.constant([3, 1], name="small") @@ -196,7 +196,7 @@ First 2 elements of y: with self.assertRaisesOpError("small.*big"): out.eval(feed_dict={small: [3, 1], big: [4, 2]}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_equal_and_broadcastable_shapes(self): small = constant_op.constant([[1, 2], [1, 2]], name="small") small_2 = constant_op.constant([1, 2], name="small_2") @@ -204,7 +204,7 @@ First 2 elements of y: out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_equal_but_non_broadcastable_shapes(self): small = constant_op.constant([1, 1, 1], name="small") small_2 = constant_op.constant([1, 1], name="small_2") @@ -219,13 +219,13 @@ First 2 elements of y: out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_not_equal_and_broadcastable_shapes(self): cond = constant_op.constant([True, False], name="small") with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"): check_ops.assert_equal(cond, False, message="fail") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): larry = constant_op.constant([]) curly = constant_op.constant([]) @@ -236,7 +236,7 @@ First 2 elements of y: class AssertNoneEqualTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_not_equal(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([10, 20], name="small") @@ -245,7 +245,7 @@ class AssertNoneEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_equal(self): small = constant_op.constant([3, 1], name="small") with self.assertRaisesOpError("x != y did not hold"): @@ -254,7 +254,7 @@ class AssertNoneEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_not_equal_and_broadcastable_shapes(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([3], name="big") @@ -263,7 +263,7 @@ class AssertNoneEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_not_equal_but_non_broadcastable_shapes(self): with self.test_session(): small = constant_op.constant([1, 1, 1], name="small") @@ -280,7 +280,7 @@ class AssertNoneEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): with self.test_session(): larry = constant_op.constant([]) @@ -300,7 +300,7 @@ class AssertNoneEqualTest(test.TestCase): class AssertAllCloseTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_equal(self): x = constant_op.constant(1., name="x") y = constant_op.constant(1., name="y") @@ -309,7 +309,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_close_enough_32_bit_due_to_default_rtol(self): eps = np.finfo(np.float32).eps # Default rtol/atol is 10*eps @@ -320,7 +320,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_close_enough_32_bit_due_to_default_atol(self): eps = np.finfo(np.float32).eps # Default rtol/atol is 10*eps @@ -331,7 +331,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_close_enough_64_bit_due_to_default_rtol(self): eps = np.finfo(np.float64).eps # Default rtol/atol is 10*eps @@ -342,7 +342,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_close_enough_64_bit_due_to_default_atol(self): eps = np.finfo(np.float64).eps # Default rtol/atol is 10*eps @@ -353,7 +353,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_close_enough_due_to_custom_rtol(self): x = constant_op.constant(1., name="x") y = constant_op.constant(1.1, name="y") @@ -363,7 +363,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_close_enough_due_to_custom_atol(self): x = constant_op.constant(0., name="x") y = constant_op.constant(0.1, name="y", dtype=np.float32) @@ -373,7 +373,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): larry = constant_op.constant([]) curly = constant_op.constant([]) @@ -381,7 +381,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(larry) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_atol_violated(self): x = constant_op.constant(10., name="x") y = constant_op.constant(10.2, name="y") @@ -392,7 +392,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_default_rtol_violated(self): x = constant_op.constant(0.1, name="x") y = constant_op.constant(0.0, name="y") @@ -412,7 +412,7 @@ class AssertAllCloseTest(test.TestCase): class AssertLessTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_equal(self): small = constant_op.constant([1, 2], name="small") with self.assertRaisesOpError("failure message.*\n*.* x < y did not hold"): @@ -422,7 +422,7 @@ class AssertLessTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_greater(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([3, 4], name="big") @@ -431,7 +431,7 @@ class AssertLessTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_less(self): small = constant_op.constant([3, 1], name="small") big = constant_op.constant([4, 2], name="big") @@ -439,7 +439,7 @@ class AssertLessTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_less_and_broadcastable_shapes(self): small = constant_op.constant([1], name="small") big = constant_op.constant([3, 2], name="big") @@ -447,7 +447,7 @@ class AssertLessTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_less_but_non_broadcastable_shapes(self): small = constant_op.constant([1, 1, 1], name="small") big = constant_op.constant([3, 2], name="big") @@ -462,7 +462,7 @@ class AssertLessTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): larry = constant_op.constant([]) curly = constant_op.constant([]) @@ -480,7 +480,7 @@ class AssertLessTest(test.TestCase): class AssertLessEqualTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_equal(self): small = constant_op.constant([1, 2], name="small") with ops.control_dependencies( @@ -488,7 +488,7 @@ class AssertLessEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_greater(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([3, 4], name="big") @@ -499,7 +499,7 @@ class AssertLessEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_less_equal(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([3, 2], name="big") @@ -507,7 +507,7 @@ class AssertLessEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_less_equal_and_broadcastable_shapes(self): small = constant_op.constant([1], name="small") big = constant_op.constant([3, 1], name="big") @@ -515,7 +515,7 @@ class AssertLessEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_less_equal_but_non_broadcastable_shapes(self): small = constant_op.constant([3, 1], name="small") big = constant_op.constant([1, 1, 1], name="big") @@ -531,7 +531,7 @@ class AssertLessEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): larry = constant_op.constant([]) curly = constant_op.constant([]) @@ -543,7 +543,7 @@ class AssertLessEqualTest(test.TestCase): class AssertGreaterTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_equal(self): small = constant_op.constant([1, 2], name="small") with self.assertRaisesOpError("fail"): @@ -553,7 +553,7 @@ class AssertGreaterTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_less(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([3, 4], name="big") @@ -562,7 +562,7 @@ class AssertGreaterTest(test.TestCase): out = array_ops.identity(big) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_greater(self): small = constant_op.constant([3, 1], name="small") big = constant_op.constant([4, 2], name="big") @@ -570,7 +570,7 @@ class AssertGreaterTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_greater_and_broadcastable_shapes(self): small = constant_op.constant([1], name="small") big = constant_op.constant([3, 2], name="big") @@ -578,7 +578,7 @@ class AssertGreaterTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_greater_but_non_broadcastable_shapes(self): small = constant_op.constant([1, 1, 1], name="small") big = constant_op.constant([3, 2], name="big") @@ -593,7 +593,7 @@ class AssertGreaterTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): larry = constant_op.constant([]) curly = constant_op.constant([]) @@ -604,7 +604,7 @@ class AssertGreaterTest(test.TestCase): class AssertGreaterEqualTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_equal(self): small = constant_op.constant([1, 2], name="small") with ops.control_dependencies( @@ -612,7 +612,7 @@ class AssertGreaterEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_less(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([3, 4], name="big") @@ -623,7 +623,7 @@ class AssertGreaterEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_greater_equal(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([3, 2], name="big") @@ -632,7 +632,7 @@ class AssertGreaterEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_greater_equal_and_broadcastable_shapes(self): small = constant_op.constant([1], name="small") big = constant_op.constant([3, 1], name="big") @@ -641,7 +641,7 @@ class AssertGreaterEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_less_equal_but_non_broadcastable_shapes(self): small = constant_op.constant([1, 1, 1], name="big") big = constant_op.constant([3, 1], name="small") @@ -657,7 +657,7 @@ class AssertGreaterEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): larry = constant_op.constant([]) curly = constant_op.constant([]) @@ -669,14 +669,14 @@ class AssertGreaterEqualTest(test.TestCase): class AssertNegativeTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_negative(self): frank = constant_op.constant([-1, -2], name="frank") with ops.control_dependencies([check_ops.assert_negative(frank)]): out = array_ops.identity(frank) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_positive(self): doug = constant_op.constant([1, 2], name="doug") with self.assertRaisesOpError("fail"): @@ -686,7 +686,7 @@ class AssertNegativeTest(test.TestCase): out = array_ops.identity(doug) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_zero(self): claire = constant_op.constant([0], name="claire") with self.assertRaisesOpError("x < 0 did not hold"): @@ -694,7 +694,7 @@ class AssertNegativeTest(test.TestCase): out = array_ops.identity(claire) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_empty_tensor_doesnt_raise(self): # A tensor is negative when it satisfies: # For every element x_i in x, x_i < 0 @@ -708,7 +708,7 @@ class AssertNegativeTest(test.TestCase): class AssertPositiveTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_negative(self): freddie = constant_op.constant([-1, -2], name="freddie") with self.assertRaisesOpError("fail"): @@ -718,14 +718,14 @@ class AssertPositiveTest(test.TestCase): out = array_ops.identity(freddie) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_positive(self): remmy = constant_op.constant([1, 2], name="remmy") with ops.control_dependencies([check_ops.assert_positive(remmy)]): out = array_ops.identity(remmy) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_zero(self): meechum = constant_op.constant([0], name="meechum") with self.assertRaisesOpError("x > 0 did not hold"): @@ -733,7 +733,7 @@ class AssertPositiveTest(test.TestCase): out = array_ops.identity(meechum) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_empty_tensor_doesnt_raise(self): # A tensor is positive when it satisfies: # For every element x_i in x, x_i > 0 @@ -747,7 +747,7 @@ class AssertPositiveTest(test.TestCase): class AssertRankTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self): tensor = constant_op.constant(1, name="my_tensor") desired_rank = 1 @@ -768,7 +768,7 @@ class AssertRankTest(test.TestCase): with self.assertRaisesOpError("fail.*my_tensor.*rank"): array_ops.identity(tensor).eval(feed_dict={tensor: 0}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_static_rank(self): tensor = constant_op.constant(1, name="my_tensor") desired_rank = 0 @@ -784,7 +784,7 @@ class AssertRankTest(test.TestCase): [check_ops.assert_rank(tensor, desired_rank)]): array_ops.identity(tensor).eval(feed_dict={tensor: 0}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_tensor_raises_if_rank_too_large_static_rank(self): tensor = constant_op.constant([1, 2], name="my_tensor") desired_rank = 0 @@ -802,7 +802,7 @@ class AssertRankTest(test.TestCase): with self.assertRaisesOpError("my_tensor.*rank"): array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_tensor_doesnt_raise_if_rank_just_right_static_rank(self): tensor = constant_op.constant([1, 2], name="my_tensor") desired_rank = 1 @@ -818,7 +818,7 @@ class AssertRankTest(test.TestCase): [check_ops.assert_rank(tensor, desired_rank)]): array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self): tensor = constant_op.constant([1, 2], name="my_tensor") desired_rank = 2 @@ -836,7 +836,7 @@ class AssertRankTest(test.TestCase): with self.assertRaisesOpError("my_tensor.*rank"): array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_if_rank_is_not_scalar_static(self): tensor = constant_op.constant([1, 2], name="my_tensor") with self.assertRaisesRegexp(ValueError, "Rank must be a scalar"): @@ -852,7 +852,7 @@ class AssertRankTest(test.TestCase): [check_ops.assert_rank(tensor, rank_tensor)]): array_ops.identity(tensor).eval(feed_dict={rank_tensor: [1, 2]}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_if_rank_is_not_integer_static(self): tensor = constant_op.constant([1, 2], name="my_tensor") with self.assertRaisesRegexp(TypeError, @@ -873,7 +873,7 @@ class AssertRankTest(test.TestCase): class AssertRankInTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_zero_tensor_raises_if_rank_mismatch_static_rank(self): tensor_rank0 = constant_op.constant(42, name="my_tensor") with self.assertRaisesRegexp( @@ -890,7 +890,7 @@ class AssertRankInTest(test.TestCase): with self.assertRaisesOpError("fail.*my_tensor.*rank"): array_ops.identity(tensor_rank0).eval(feed_dict={tensor_rank0: 42.0}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_zero_tensor_doesnt_raise_if_rank_matches_static_rank(self): tensor_rank0 = constant_op.constant(42, name="my_tensor") for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)): @@ -906,7 +906,7 @@ class AssertRankInTest(test.TestCase): check_ops.assert_rank_in(tensor_rank0, desired_ranks)]): array_ops.identity(tensor_rank0).eval(feed_dict={tensor_rank0: 42.0}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_tensor_doesnt_raise_if_rank_matches_static_rank(self): tensor_rank1 = constant_op.constant([42, 43], name="my_tensor") for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)): @@ -924,7 +924,7 @@ class AssertRankInTest(test.TestCase): tensor_rank1: (42.0, 43.0) }) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_tensor_raises_if_rank_mismatches_static_rank(self): tensor_rank1 = constant_op.constant((42, 43), name="my_tensor") with self.assertRaisesRegexp(ValueError, "rank"): @@ -942,7 +942,7 @@ class AssertRankInTest(test.TestCase): tensor_rank1: (42.0, 43.0) }) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_if_rank_is_not_scalar_static(self): tensor = constant_op.constant((42, 43), name="my_tensor") desired_ranks = ( @@ -966,7 +966,7 @@ class AssertRankInTest(test.TestCase): desired_ranks[1]: [2, 1], }) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_if_rank_is_not_integer_static(self): tensor = constant_op.constant((42, 43), name="my_tensor") with self.assertRaisesRegexp(TypeError, @@ -987,7 +987,7 @@ class AssertRankInTest(test.TestCase): class AssertRankAtLeastTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self): tensor = constant_op.constant(1, name="my_tensor") desired_rank = 1 @@ -1005,7 +1005,7 @@ class AssertRankAtLeastTest(test.TestCase): with self.assertRaisesOpError("my_tensor.*rank"): array_ops.identity(tensor).eval(feed_dict={tensor: 0}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_static_rank(self): tensor = constant_op.constant(1, name="my_tensor") desired_rank = 0 @@ -1021,7 +1021,7 @@ class AssertRankAtLeastTest(test.TestCase): [check_ops.assert_rank_at_least(tensor, desired_rank)]): array_ops.identity(tensor).eval(feed_dict={tensor: 0}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_ten_doesnt_raise_raise_if_rank_too_large_static_rank(self): tensor = constant_op.constant([1, 2], name="my_tensor") desired_rank = 0 @@ -1037,7 +1037,7 @@ class AssertRankAtLeastTest(test.TestCase): [check_ops.assert_rank_at_least(tensor, desired_rank)]): array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_tensor_doesnt_raise_if_rank_just_right_static_rank(self): tensor = constant_op.constant([1, 2], name="my_tensor") desired_rank = 1 @@ -1053,7 +1053,7 @@ class AssertRankAtLeastTest(test.TestCase): [check_ops.assert_rank_at_least(tensor, desired_rank)]): array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self): tensor = constant_op.constant([1, 2], name="my_tensor") desired_rank = 2 @@ -1074,7 +1074,7 @@ class AssertRankAtLeastTest(test.TestCase): class AssertNonNegativeTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_negative(self): zoe = constant_op.constant([-1, -2], name="zoe") with self.assertRaisesOpError("x >= 0 did not hold"): @@ -1082,14 +1082,14 @@ class AssertNonNegativeTest(test.TestCase): out = array_ops.identity(zoe) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_zero_and_positive(self): lucas = constant_op.constant([0, 2], name="lucas") with ops.control_dependencies([check_ops.assert_non_negative(lucas)]): out = array_ops.identity(lucas) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_empty_tensor_doesnt_raise(self): # A tensor is non-negative when it satisfies: # For every element x_i in x, x_i >= 0 @@ -1103,14 +1103,14 @@ class AssertNonNegativeTest(test.TestCase): class AssertNonPositiveTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_zero_and_negative(self): tom = constant_op.constant([0, -2], name="tom") with ops.control_dependencies([check_ops.assert_non_positive(tom)]): out = array_ops.identity(tom) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_positive(self): rachel = constant_op.constant([0, 2], name="rachel") with self.assertRaisesOpError("x <= 0 did not hold"): @@ -1118,7 +1118,7 @@ class AssertNonPositiveTest(test.TestCase): out = array_ops.identity(rachel) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_empty_tensor_doesnt_raise(self): # A tensor is non-positive when it satisfies: # For every element x_i in x, x_i <= 0 @@ -1132,14 +1132,14 @@ class AssertNonPositiveTest(test.TestCase): class AssertIntegerTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_integer(self): integers = constant_op.constant([1, 2], name="integers") with ops.control_dependencies([check_ops.assert_integer(integers)]): out = array_ops.identity(integers) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_float(self): floats = constant_op.constant([1.0, 2.0], name="floats") with self.assertRaisesRegexp(TypeError, "Expected.*integer"): @@ -1148,7 +1148,7 @@ class AssertIntegerTest(test.TestCase): class AssertTypeTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_correct_type(self): integers = constant_op.constant([1, 2], dtype=dtypes.int64) with ops.control_dependencies([ @@ -1156,7 +1156,7 @@ class AssertTypeTest(test.TestCase): out = array_ops.identity(integers) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_wrong_type(self): floats = constant_op.constant([1.0, 2.0], dtype=dtypes.float16) with self.assertRaisesRegexp(TypeError, "must be of type.*float32"): @@ -1165,74 +1165,74 @@ class AssertTypeTest(test.TestCase): class IsStrictlyIncreasingTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_constant_tensor_is_not_strictly_increasing(self): self.assertFalse(self.evaluate(check_ops.is_strictly_increasing([1, 1, 1]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_decreasing_tensor_is_not_strictly_increasing(self): self.assertFalse(self.evaluate( check_ops.is_strictly_increasing([1, 0, -1]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_2d_decreasing_tensor_is_not_strictly_increasing(self): self.assertFalse( self.evaluate(check_ops.is_strictly_increasing([[1, 3], [2, 4]]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_increasing_tensor_is_increasing(self): self.assertTrue(self.evaluate(check_ops.is_strictly_increasing([1, 2, 3]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_increasing_rank_two_tensor(self): self.assertTrue( self.evaluate(check_ops.is_strictly_increasing([[-1, 2], [3, 4]]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_tensor_with_one_element_is_strictly_increasing(self): self.assertTrue(self.evaluate(check_ops.is_strictly_increasing([1]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_empty_tensor_is_strictly_increasing(self): self.assertTrue(self.evaluate(check_ops.is_strictly_increasing([]))) class IsNonDecreasingTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_constant_tensor_is_non_decreasing(self): self.assertTrue(self.evaluate(check_ops.is_non_decreasing([1, 1, 1]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_decreasing_tensor_is_not_non_decreasing(self): self.assertFalse(self.evaluate(check_ops.is_non_decreasing([3, 2, 1]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_2d_decreasing_tensor_is_not_non_decreasing(self): self.assertFalse(self.evaluate( check_ops.is_non_decreasing([[1, 3], [2, 4]]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_increasing_rank_one_tensor_is_non_decreasing(self): self.assertTrue(self.evaluate(check_ops.is_non_decreasing([1, 2, 3]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_increasing_rank_two_tensor(self): self.assertTrue(self.evaluate( check_ops.is_non_decreasing([[-1, 2], [3, 3]]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_tensor_with_one_element_is_non_decreasing(self): self.assertTrue(self.evaluate(check_ops.is_non_decreasing([1]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_empty_tensor_is_non_decreasing(self): self.assertTrue(self.evaluate(check_ops.is_non_decreasing([]))) class FloatDTypeTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_assert_same_float_dtype(self): self.assertIs(dtypes.float32, check_ops.assert_same_float_dtype(None, None)) @@ -1286,7 +1286,7 @@ class FloatDTypeTest(test.TestCase): class AssertScalarTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_assert_scalar(self): check_ops.assert_scalar(constant_op.constant(3)) check_ops.assert_scalar(constant_op.constant("foo")) diff --git a/tensorflow/contrib/control_flow/python/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py similarity index 71% rename from tensorflow/contrib/control_flow/python/cond_v2_test.py rename to tensorflow/python/kernel_tests/cond_v2_test.py index 94ed3e130ba06c129d96c4ea775a043b5bc9b3ea..759db5d5f43a144150918446e6ce206b3095904f 100644 --- a/tensorflow/contrib/control_flow/python/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -19,11 +19,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.control_flow.python import cond_v2 +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import cond_v2 from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gradients_impl @@ -37,15 +38,15 @@ from tensorflow.python.util import compat class NewCondTest(test.TestCase): def _testCond(self, true_fn, false_fn, train_vals): - pred = array_ops.placeholder(dtypes.bool, name="pred") + with self.test_session() as sess: + pred = array_ops.placeholder(dtypes.bool, name="pred") - expected = control_flow_ops.cond(pred, true_fn, false_fn, name="expected") - actual = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual") + expected = control_flow_ops.cond(pred, true_fn, false_fn, name="expected") + actual = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual") - expected_grad = gradients_impl.gradients(expected, train_vals) - actual_grad = gradients_impl.gradients(actual, train_vals) + expected_grad = gradients_impl.gradients(expected, train_vals) + actual_grad = gradients_impl.gradients(actual, train_vals) - with self.test_session() as sess: expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run( (expected, actual, expected_grad, actual_grad), {pred: True}) self.assertEqual(expected_val, actual_val) @@ -85,22 +86,22 @@ class NewCondTest(test.TestCase): self._testCond(true_fn, false_fn, [y]) def testNoInputs(self): - pred = array_ops.placeholder(dtypes.bool, name="pred") + with self.test_session() as sess: + pred = array_ops.placeholder(dtypes.bool, name="pred") - def true_fn(): - return constant_op.constant(1.0) + def true_fn(): + return constant_op.constant(1.0) - def false_fn(): - return constant_op.constant(2.0) + def false_fn(): + return constant_op.constant(2.0) - out = cond_v2.cond_v2(pred, true_fn, false_fn) + out = cond_v2.cond_v2(pred, true_fn, false_fn) - with self.test_session() as sess: self.assertEqual(sess.run(out, {pred: True}), [1.0]) self.assertEqual(sess.run(out, {pred: False}), [2.0]) def _createCond(self, name): - pred = array_ops.placeholder(dtypes.bool, name="pred") + pred = constant_op.constant(True, name="pred") x = constant_op.constant(1.0, name="x") def true_fn(): @@ -131,20 +132,20 @@ class NewCondTest(test.TestCase): self.assertIn("foo_cond_1_false", ops.get_default_graph()._functions) def testSecondDerivative(self): - pred = array_ops.placeholder(dtypes.bool, name="pred") - x = constant_op.constant(3.0, name="x") + with self.test_session() as sess: + pred = array_ops.placeholder(dtypes.bool, name="pred") + x = constant_op.constant(3.0, name="x") - def true_fn(): - return math_ops.pow(x, 3) + def true_fn(): + return math_ops.pow(x, 3) - def false_fn(): - return x + def false_fn(): + return x - cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond") - cond_grad = gradients_impl.gradients(cond, [x]) - cond_grad_grad = gradients_impl.gradients(cond_grad, [x]) + cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond") + cond_grad = gradients_impl.gradients(cond, [x]) + cond_grad_grad = gradients_impl.gradients(cond_grad, [x]) - with self.test_session() as sess: # d[x^3]/dx = 3x^2 true_val = sess.run(cond_grad, {pred: True}) self.assertEqual(true_val, [27.0]) @@ -178,14 +179,14 @@ class NewCondTest(test.TestCase): meta_graph = saver.export_meta_graph() with ops.Graph().as_default() as g: - saver.import_meta_graph(meta_graph) - x = ops.get_collection("x")[0] - pred = ops.get_collection("pred")[0] - cond = ops.get_collection("cond") - cond_grad = gradients_impl.gradients(cond, [x], name="cond_grad") - cond_grad_grad = gradients_impl.gradients( - cond_grad, [x], name="cond_grad_grad") with self.test_session(graph=g) as sess: + saver.import_meta_graph(meta_graph) + x = ops.get_collection("x")[0] + pred = ops.get_collection("pred")[0] + cond = ops.get_collection("cond") + cond_grad = gradients_impl.gradients(cond, [x], name="cond_grad") + cond_grad_grad = gradients_impl.gradients( + cond_grad, [x], name="cond_grad_grad") # d[x^3]/dx = 3x^2 true_val = sess.run(cond_grad, {pred: True}) self.assertEqual(true_val, [27.0]) @@ -200,6 +201,65 @@ class NewCondTest(test.TestCase): # d2[x]/dx2 = 0 self.assertEqual(false_val, [0.0]) + def testLowering(self): + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + out_cond = self._createCond("cond") + + run_options = config_pb2.RunOptions(output_partition_graphs=True) + run_metadata = config_pb2.RunMetadata() + sess.run(out_cond, options=run_options, run_metadata=run_metadata) + + # If lowering was enabled, there should be a `Switch` node + switch_found = any( + any(node.op == "Switch" for node in graph.node) + for graph in run_metadata.partition_graphs + ) + + self.assertTrue(switch_found, + "A `Switch` op should exist if the graph was lowered.") + + # If lowering was enabled, there should be no `If` node + if_found = any( + any(node.op == "If" for node in graph.node) + for graph in run_metadata.partition_graphs + ) + + self.assertFalse(if_found, + "An `If` op was found, but it should be lowered.") + + def testLoweringDisabledInXLA(self): + with self.test_session(graph=ops.Graph()) as sess: + # Build the cond_v2 in an XLA context + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + out_cond = self._createCond("cond") + xla_context.Exit() + + run_options = config_pb2.RunOptions(output_partition_graphs=True) + run_metadata = config_pb2.RunMetadata() + sess.run(out_cond, options=run_options, run_metadata=run_metadata) + + # Lowering disabled in XLA, there should be no `Switch` node + switch_found = any( + any(node.op == "Switch" for node in graph.node) + for graph in run_metadata.partition_graphs + ) + + self.assertFalse( + switch_found, + "A `Switch` op exists, but the graph should not be lowered.") + + # Lowering disabled in XLA, there should still be an `If` node + if_found = any( + any(node.op == "If" for node in graph.node) + for graph in run_metadata.partition_graphs + ) + + self.assertTrue( + if_found, + "An `If` op was not found, but the graph should not be lowered.") + class CondV2CollectionTest(test.TestCase): @@ -387,6 +447,34 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): d = constant_op.constant([2.0], name="d") self.assertEqual([b"loc:@a"], d.op.colocation_groups()) + def testColocateWithInCondGraphPartitioning(self): + with ops.Graph().as_default() as g: + with self.test_session( + graph=g, + config=config_pb2.ConfigProto(device_count={"CPU": 2}) + ) as sess: + + with ops.device("/device:CPU:0"): + a = constant_op.constant([2.0], name="a") + with ops.device("/device:CPU:1"): + b = constant_op.constant([2.0], name="b") + + def fn(): + with ops.colocate_with(b.op): + c = math_ops.add(a, a, name="c") + return c + out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0] + + run_options = config_pb2.RunOptions(output_partition_graphs=True) + run_metadata = config_pb2.RunMetadata() + sess.run(out_cond_2, options=run_options, run_metadata=run_metadata) + + # We expect there to be two partitions because of the + # colocate_with. We are only running the cond, which has a data + # dependency on `a` but not on `b`. So, without the colocate_with + # we would expect execution on just one device. + self.assertTrue(len(run_metadata.partition_graphs) >= 2) + def testDeviceBeforeCond(self): with ops.Graph().as_default() as g: with self.test_session(graph=g): @@ -421,5 +509,28 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase): d = constant_op.constant(4.0) self.assertEqual("/device:CPU:0", d.op.device) + def testDeviceInCondGraphPartitioning(self): + with ops.Graph().as_default() as g: + with self.test_session( + graph=g, + config=config_pb2.ConfigProto(device_count={"CPU": 2}) + ) as sess: + + def fn(): + with ops.device("/device:CPU:1"): + c = math_ops.add(a, a, name="c") + return c + + with ops.device("/device:CPU:0"): + a = constant_op.constant([2.0], name="a") + out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0] + + run_options = config_pb2.RunOptions(output_partition_graphs=True) + run_metadata = config_pb2.RunMetadata() + sess.run(out_cond_2, options=run_options, run_metadata=run_metadata) + + self.assertTrue(len(run_metadata.partition_graphs) >= 2) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/confusion_matrix_test.py b/tensorflow/python/kernel_tests/confusion_matrix_test.py index 79e419867d70071280b7c88b6bfa820b935b24cd..ae6875340e776fc6808be3f4afeb59644245c886 100644 --- a/tensorflow/python/kernel_tests/confusion_matrix_test.py +++ b/tensorflow/python/kernel_tests/confusion_matrix_test.py @@ -34,7 +34,7 @@ from tensorflow.python.platform import test class ConfusionMatrixTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testExample(self): """This is a test of the example provided in pydoc.""" with self.test_session(): diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index 80ba7dafc9d03c9cef472343fbcbba8975c1557e..474d06b8f3a4276c65711d74ba0d1db6fb06cbf9 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -345,7 +345,7 @@ class Conv2DTest(test.TestCase): self.assertAllClose(expected, np.ravel(value), atol=tol, rtol=tol) self.assertShapeEqual(value, conv) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D1x1Filter(self): expected_output = [ 30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, 138.0, 171.0, @@ -358,7 +358,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Filter2x1Dilation(self): self._VerifyDilatedConvValues( tensor_in_sizes=[1, 4, 4, 1], @@ -367,7 +367,7 @@ class Conv2DTest(test.TestCase): dilations=[2, 1], padding="VALID") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DEmpty(self): expected_output = [] self._VerifyValues( @@ -377,7 +377,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DEmptyDilation(self): self._VerifyDilatedConvValues( tensor_in_sizes=[0, 2, 3, 3], @@ -386,7 +386,7 @@ class Conv2DTest(test.TestCase): dilations=[2, 1], padding="VALID") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Filter(self): # The outputs are computed using third_party/py/IPython/notebook. expected_output = [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0] @@ -397,7 +397,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2FilterDilation(self): self._VerifyDilatedConvValues( tensor_in_sizes=[1, 2, 3, 3], @@ -406,7 +406,7 @@ class Conv2DTest(test.TestCase): dilations=[1, 2], padding="VALID") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D1x2Filter(self): # The outputs are computed using third_party/py/IPython/notebook. expected_output = [ @@ -420,7 +420,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D1x2FilterDilation(self): self._VerifyDilatedConvValues( tensor_in_sizes=[1, 2, 3, 3], @@ -429,7 +429,7 @@ class Conv2DTest(test.TestCase): dilations=[2, 1], padding="VALID") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2FilterStride2(self): expected_output = [2271.0, 2367.0, 2463.0] self._VerifyValues( @@ -439,7 +439,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2FilterStride2Same(self): expected_output = [2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0] self._VerifyValues( @@ -449,7 +449,7 @@ class Conv2DTest(test.TestCase): padding="SAME", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2FilterStride1x2(self): expected_output = [58.0, 78.0, 98.0, 118.0, 138.0, 158.0] self._VerifyValues( @@ -459,7 +459,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DKernelSmallerThanStrideValid(self): expected_output = [65, 95, 275, 305] self._VerifyValues( @@ -469,7 +469,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DKernelSmallerThanStrideSame(self): self._VerifyValues( tensor_in_sizes=[1, 3, 3, 1], @@ -492,7 +492,7 @@ class Conv2DTest(test.TestCase): padding="SAME", expected=[44, 28, 41, 16]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DKernelSizeMatchesInputSize(self): self._VerifyValues( tensor_in_sizes=[1, 2, 2, 1], @@ -501,7 +501,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=[50, 60]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DKernelSizeMatchesInputSizeDilation(self): self._VerifyDilatedConvValues( tensor_in_sizes=[1, 3, 3, 1], @@ -589,7 +589,7 @@ class Conv2DTest(test.TestCase): for i in range(1, len(values)): self.assertAllClose(values[0], values[i], rtol=1e-2, atol=1e-2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Depth1ValidBackpropInput(self): expected_output = [1.0, 4.0, 4.0, 3.0, 10.0, 8.0] for (data_format, use_gpu) in GetTestConfigs(): @@ -604,7 +604,7 @@ class Conv2DTest(test.TestCase): use_gpu=use_gpu, err=1e-5) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DEmptyBackpropInput(self): expected_output = [] for (data_format, use_gpu) in GetTestConfigs(): @@ -619,7 +619,7 @@ class Conv2DTest(test.TestCase): use_gpu=use_gpu, err=1e-5) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Depth3ValidBackpropInput(self): expected_output = [ 14.0, 32.0, 50.0, 100.0, 163.0, 226.0, 167.0, 212.0, 257.0, 122.0, @@ -639,7 +639,7 @@ class Conv2DTest(test.TestCase): use_gpu=use_gpu, err=1e-4) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Depth3ValidBackpropInputStride1x2(self): expected_output = [ 1.0, 2.0, 2.0, 4.0, 3.0, 6.0, 7.0, 12.0, 11.0, 18.0, 15.0, 24.0, 12.0, @@ -657,7 +657,7 @@ class Conv2DTest(test.TestCase): use_gpu=use_gpu, err=1e-5) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DStrideTwoFilterOneSameBackpropInput(self): expected_output = [ 1.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 4.0, 0.0, 0.0, 0.0, @@ -675,7 +675,7 @@ class Conv2DTest(test.TestCase): use_gpu=use_gpu, err=1e-5) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DKernelSizeMatchesInputSizeBackpropInput(self): expected_output = [5.0, 11.0, 17.0, 23.0] for (data_format, use_gpu) in GetTestConfigs(): @@ -759,7 +759,7 @@ class Conv2DTest(test.TestCase): for i in range(1, len(values)): self.assertAllClose(values[0], values[i], rtol=1e-4, atol=1e-4) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Depth1ValidBackpropFilter(self): expected = [5.0, 8.0, 14.0, 17.0] for (data_format, use_gpu) in GetTestConfigs(): @@ -773,7 +773,7 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DEmptyBackpropFilter(self): expected = [] for (data_format, use_gpu) in GetTestConfigs(): @@ -787,7 +787,7 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DBackpropFilterWithEmptyInput(self): expected = [0, 0, 0, 0] for (data_format, use_gpu) in GetTestConfigs(): @@ -801,7 +801,7 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Depth3ValidBackpropFilter(self): expected = [ 17.0, 22.0, 27.0, 22.0, 29.0, 36.0, 27.0, 36.0, 45.0, 32.0, 43.0, 54.0, @@ -820,7 +820,7 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Depth3ValidBackpropFilterStride1x2(self): expected = [161.0, 182.0, 287.0, 308.0] for (data_format, use_gpu) in GetTestConfigs(): @@ -834,7 +834,7 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DStrideTwoFilterOneSameBackpropFilter(self): expected_output = [78.] for (data_format, use_gpu) in GetTestConfigs(): @@ -848,7 +848,7 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DKernelSizeMatchesInputSizeBackpropFilter(self): expected_output = [1.0, 2.0, 2.0, 4.0, 3.0, 6.0, 4.0, 8.0] for (data_format, use_gpu) in GetTestConfigs(): @@ -1897,19 +1897,19 @@ if __name__ == "__main__": for index, (input_size_, filter_size_, output_size_, stride_, padding_) in enumerate(GetShrunkInceptionShapes()): setattr(Conv2DTest, "testInceptionFwd_" + str(index), - test_util.run_in_graph_and_eager_modes()( + test_util.run_in_graph_and_eager_modes( GetInceptionFwdTest(input_size_, filter_size_, stride_, padding_))) setattr( Conv2DTest, "testInceptionFwdDilatedConv_" + str(index), - test_util.run_in_graph_and_eager_modes()(GetInceptionFwdDilatedConvTest( + test_util.run_in_graph_and_eager_modes(GetInceptionFwdDilatedConvTest( input_size_, filter_size_, stride_, padding_))) setattr(Conv2DTest, "testInceptionBackInput_" + str(index), - test_util.run_in_graph_and_eager_modes()( + test_util.run_in_graph_and_eager_modes( GetInceptionBackInputTest(input_size_, filter_size_, output_size_, stride_, padding_))) setattr(Conv2DTest, "testInceptionBackFilter_" + str(index), - test_util.run_in_graph_and_eager_modes()( + test_util.run_in_graph_and_eager_modes( GetInceptionBackFilterTest(input_size_, filter_size_, output_size_, [stride_, stride_], padding_))) @@ -1924,17 +1924,17 @@ if __name__ == "__main__": fshape = [1, 1, 1, 256] oshape = [1, 400, 400, 256] setattr(Conv2DTest, "testInceptionFwd_No_Winograd_Nonfused", - test_util.run_in_graph_and_eager_modes()( + test_util.run_in_graph_and_eager_modes( GetInceptionFwdTest(ishape, fshape, 1, "SAME", gpu_only=True))) setattr(Conv2DTest, "testInceptionFwdDilatedConv_No_Winograd_Nonfused", - test_util.run_in_graph_and_eager_modes()( + test_util.run_in_graph_and_eager_modes( GetInceptionFwdDilatedConvTest(ishape, fshape, 1, "SAME"))) setattr(Conv2DTest, "testInceptionBackInput_No_Winograd_Nonfused", - test_util.run_in_graph_and_eager_modes()( + test_util.run_in_graph_and_eager_modes( GetInceptionBackInputTest(ishape, fshape, oshape, 1, "SAME", gpu_only=True))) setattr(Conv2DTest, "testInceptionBackFilter_No_Winograd_Nonfused", - test_util.run_in_graph_and_eager_modes()( + test_util.run_in_graph_and_eager_modes( GetInceptionBackFilterTest(ishape, fshape, oshape, [1, 1], "SAME", gpu_only=True))) test.main() diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index 8a3e64b174d9572ce11981a9d8d0e71cd9a336bc..b61232cdedecacf0cc0f9b1661486a52afc86c2e 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -96,7 +96,8 @@ class UnaryOpTest(test.TestCase): np_ans = np_func(x) with self.test_session(use_gpu=False): inx = ops.convert_to_tensor(x) - if x.dtype in (np.float32, np.float64): + if x.dtype in (np.float32, np.float64, + dtypes_lib.bfloat16.as_numpy_dtype): y = 1.1 * tf_func(inx) np_ans *= 1.1 else: @@ -105,6 +106,8 @@ class UnaryOpTest(test.TestCase): self.assertShapeEqual(np_ans, y) if x.dtype == np.float16: self.assertAllClose(np_ans, tf_cpu, rtol=1e-3, atol=1e-3) + elif x.dtype == dtypes_lib.bfloat16.as_numpy_dtype: + self.assertAllClose(np_ans, tf_cpu, rtol=1e-2, atol=1e-2) else: self.assertAllClose(np_ans, tf_cpu) @@ -668,12 +671,11 @@ class BinaryOpTest(test.TestCase): self._compareCpu(x, y, np_func, tf_func, also_compare_variables) if x.dtype in (np.float16, np.float32, np.float64, np.complex64, np.complex128): - if tf_func not in (_FLOORDIV, math_ops.floordiv, math_ops.igamma, - math_ops.igammac, math_ops.zeta, math_ops.polygamma): + if tf_func not in (_FLOORDIV, math_ops.floordiv, math_ops.zeta, + math_ops.polygamma): self._compareGradientX(x, y, np_func, tf_func) self._compareGradientY(x, y, np_func, tf_func) - if tf_func in (math_ops.igamma, math_ops.igammac, math_ops.zeta, - math_ops.polygamma): + if tf_func in (math_ops.zeta, math_ops.polygamma): # These methods only support gradients in the second parameter self._compareGradientY(x, y, np_func, tf_func) self._compareGpu(x, y, np_func, tf_func) diff --git a/tensorflow/python/kernel_tests/dct_ops_test.py b/tensorflow/python/kernel_tests/dct_ops_test.py index 93b2ff4561bcc8fd13855cde444c4b6237d7949b..97d7e2d8f90a620b693e2c81adc616d399e13bd6 100644 --- a/tensorflow/python/kernel_tests/dct_ops_test.py +++ b/tensorflow/python/kernel_tests/dct_ops_test.py @@ -40,50 +40,92 @@ def try_import(name): # pylint: disable=invalid-name fftpack = try_import("scipy.fftpack") +def _np_dct2(signals, norm=None): + """Computes the DCT-II manually with NumPy.""" + # X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k) k=0,...,N-1 + dct_size = signals.shape[-1] + dct = np.zeros_like(signals) + for k in range(dct_size): + phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size) + dct[..., k] = np.sum(signals * phi, axis=-1) + # SciPy's `dct` has a scaling factor of 2.0 which we follow. + # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src + if norm == "ortho": + # The orthonormal scaling includes a factor of 0.5 which we combine with + # the overall scaling of 2.0 to cancel. + dct[..., 0] *= np.sqrt(1.0 / dct_size) + dct[..., 1:] *= np.sqrt(2.0 / dct_size) + else: + dct *= 2.0 + return dct + + +def _np_dct3(signals, norm=None): + """Computes the DCT-III manually with NumPy.""" + # SciPy's `dct` has a scaling factor of 2.0 which we follow. + # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src + dct_size = signals.shape[-1] + signals = np.array(signals) # make a copy so we can modify + if norm == "ortho": + signals[..., 0] *= np.sqrt(4.0 / dct_size) + signals[..., 1:] *= np.sqrt(2.0 / dct_size) + else: + signals *= 2.0 + dct = np.zeros_like(signals) + # X_k = 0.5 * x_0 + + # sum_{n=1}^{N-1} x_n * cos(\frac{pi}{N} * n * (k + 0.5)) k=0,...,N-1 + half_x0 = 0.5 * signals[..., 0] + for k in range(dct_size): + phi = np.cos(np.pi * np.arange(1, dct_size) * (k + 0.5) / dct_size) + dct[..., k] = half_x0 + np.sum(signals[..., 1:] * phi, axis=-1) + return dct + + +NP_DCT = {2: _np_dct2, 3: _np_dct3} +NP_IDCT = {2: _np_dct3, 3: _np_dct2} + + class DCTOpsTest(test.TestCase): - def _np_dct2(self, signals, norm=None): - """Computes the DCT-II manually with NumPy.""" - # X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k) k=0,...,N-1 - dct_size = signals.shape[-1] - dct = np.zeros_like(signals) - for k in range(dct_size): - phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size) - dct[..., k] = np.sum(signals * phi, axis=-1) - # SciPy's `dct` has a scaling factor of 2.0 which we follow. - # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src - if norm == "ortho": - # The orthonormal scaling includes a factor of 0.5 which we combine with - # the overall scaling of 2.0 to cancel. - dct[..., 0] *= np.sqrt(1.0 / dct_size) - dct[..., 1:] *= np.sqrt(2.0 / dct_size) - else: - dct *= 2.0 - return dct - - def _compare(self, signals, norm, atol=5e-4, rtol=5e-4): - """Compares the DCT to SciPy (if available) and a NumPy implementation.""" - np_dct = self._np_dct2(signals, norm) - tf_dct = spectral_ops.dct(signals, type=2, norm=norm).eval() + def _compare(self, signals, norm, dct_type, atol=5e-4, rtol=5e-4): + """Compares (I)DCT to SciPy (if available) and a NumPy implementation.""" + np_dct = NP_DCT[dct_type](signals, norm) + tf_dct = spectral_ops.dct(signals, type=dct_type, norm=norm).eval() self.assertAllClose(np_dct, tf_dct, atol=atol, rtol=rtol) + np_idct = NP_IDCT[dct_type](signals, norm) + tf_idct = spectral_ops.idct(signals, type=dct_type, norm=norm).eval() + self.assertAllClose(np_idct, tf_idct, atol=atol, rtol=rtol) if fftpack: - scipy_dct = fftpack.dct(signals, type=2, norm=norm) + scipy_dct = fftpack.dct(signals, type=dct_type, norm=norm) self.assertAllClose(scipy_dct, tf_dct, atol=atol, rtol=rtol) + scipy_idct = fftpack.idct(signals, type=dct_type, norm=norm) + self.assertAllClose(scipy_idct, tf_idct, atol=atol, rtol=rtol) + # Verify inverse(forward(s)) == s, up to a normalization factor. + tf_idct_dct = spectral_ops.idct( + tf_dct, type=dct_type, norm=norm).eval() + tf_dct_idct = spectral_ops.dct( + tf_idct, type=dct_type, norm=norm).eval() + if norm is None: + tf_idct_dct *= 0.5 / signals.shape[-1] + tf_dct_idct *= 0.5 / signals.shape[-1] + self.assertAllClose(signals, tf_idct_dct, atol=atol, rtol=rtol) + self.assertAllClose(signals, tf_dct_idct, atol=atol, rtol=rtol) def test_random(self): """Test randomly generated batches of data.""" with spectral_ops_test_util.fft_kernel_label_map(): with self.test_session(use_gpu=True): - for shape in ([2, 20], [1], [2], [3], [10], [2, 20], [2, 3, 25]): + for shape in ([1], [2], [3], [10], [2, 20], [2, 3, 25]): signals = np.random.rand(*shape).astype(np.float32) for norm in (None, "ortho"): - self._compare(signals, norm) + self._compare(signals, norm, 2) + self._compare(signals, norm, 3) def test_error(self): signals = np.random.rand(10) # Unsupported type. with self.assertRaises(ValueError): - spectral_ops.dct(signals, type=3) + spectral_ops.dct(signals, type=1) # Unknown normalization. with self.assertRaises(ValueError): spectral_ops.dct(signals, norm="bad") diff --git a/tensorflow/python/kernel_tests/distributions/BUILD b/tensorflow/python/kernel_tests/distributions/BUILD index 985922245e975dccc69275627e40b9970562f547..14532965d8c2c62139b3cd922acb9f90c0691d53 100644 --- a/tensorflow/python/kernel_tests/distributions/BUILD +++ b/tensorflow/python/kernel_tests/distributions/BUILD @@ -135,6 +135,10 @@ cuda_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], + tags = [ + "noguitar", # b/110489471 + "notap", # b/110489471 + ], ) cuda_py_test( diff --git a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py index 095d1cde1530f15fd2a7ff4cb7f56424f276be5a..9ad77a54cbc730296508e4fe74248d2413029151 100644 --- a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py +++ b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py @@ -22,6 +22,7 @@ import importlib import numpy as np +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util @@ -57,14 +58,14 @@ def entropy(p): class BernoulliTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testP(self): p = [0.2, 0.4] dist = bernoulli.Bernoulli(probs=p) with self.test_session(): self.assertAllClose(p, self.evaluate(dist.probs)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLogits(self): logits = [-42., 42.] dist = bernoulli.Bernoulli(logits=logits) @@ -82,7 +83,7 @@ class BernoulliTest(test.TestCase): with self.test_session(): self.assertAllClose(special.logit(p), self.evaluate(dist.logits)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInvalidP(self): invalid_ps = [1.01, 2.] for p in invalid_ps: @@ -104,7 +105,7 @@ class BernoulliTest(test.TestCase): dist = bernoulli.Bernoulli(probs=p) self.assertEqual(p, self.evaluate(dist.probs)) # Should not fail - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testShapes(self): with self.test_session(): for batch_shape in ([], [1], [2, 3, 4]): @@ -115,7 +116,7 @@ class BernoulliTest(test.TestCase): self.assertAllEqual([], dist.event_shape.as_list()) self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDtype(self): dist = make_bernoulli([]) self.assertEqual(dist.dtype, dtypes.int32) @@ -133,7 +134,7 @@ class BernoulliTest(test.TestCase): self.assertEqual(dist64.dtype, dist64.sample(5).dtype) self.assertEqual(dist64.dtype, dist64.mode().dtype) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def _testPmf(self, **kwargs): dist = bernoulli.Bernoulli(**kwargs) with self.test_session(): @@ -174,7 +175,7 @@ class BernoulliTest(test.TestCase): p: [0.2, 0.3, 0.4] }), [[0.2, 0.7, 0.4]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPmfInvalid(self): p = [0.1, 0.2, 0.7] with self.test_session(): @@ -184,7 +185,7 @@ class BernoulliTest(test.TestCase): with self.assertRaisesOpError("Elements cannot exceed 1."): self.evaluate(dist.prob([2, 0, 1])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPmfWithP(self): p = [[0.2, 0.4], [0.3, 0.6]] self._testPmf(probs=p) @@ -226,21 +227,21 @@ class BernoulliTest(test.TestCase): dist = bernoulli.Bernoulli(probs=[[0.5], [0.5]]) self.assertEqual((2, 1), dist.log_prob(1).get_shape()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBoundaryConditions(self): with self.test_session(): dist = bernoulli.Bernoulli(probs=1.0) self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0))) self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEntropyNoBatch(self): p = 0.2 dist = bernoulli.Bernoulli(probs=p) with self.test_session(): self.assertAllClose(self.evaluate(dist.entropy()), entropy(p)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEntropyWithBatch(self): p = [[0.1, 0.7], [0.2, 0.6]] dist = bernoulli.Bernoulli(probs=p, validate_args=False) @@ -250,7 +251,7 @@ class BernoulliTest(test.TestCase): [[entropy(0.1), entropy(0.7)], [entropy(0.2), entropy(0.6)]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSampleN(self): with self.test_session(): p = [0.2, 0.6] @@ -272,6 +273,16 @@ class BernoulliTest(test.TestCase): dist = bernoulli.Bernoulli(np.log([.2, .4])) self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list()) + @test_util.run_in_graph_and_eager_modes + def testNotReparameterized(self): + p = constant_op.constant([0.2, 0.6]) + with backprop.GradientTape() as tape: + tape.watch(p) + dist = bernoulli.Bernoulli(probs=p) + samples = dist.sample(100) + grad_p = tape.gradient(samples, p) + self.assertIsNone(grad_p) + def testSampleActsLikeSampleN(self): with self.test_session() as sess: p = [0.2, 0.6] @@ -282,18 +293,18 @@ class BernoulliTest(test.TestCase): self.evaluate(dist.sample(n, seed)), self.evaluate(dist.sample(n, seed))) n = array_ops.placeholder(dtypes.int32) - sample, sample = sess.run([dist.sample(n, seed), dist.sample(n, seed)], - feed_dict={n: 1000}) - self.assertAllEqual(sample, sample) + sample1, sample2 = sess.run([dist.sample(n, seed), dist.sample(n, seed)], + feed_dict={n: 1000}) + self.assertAllEqual(sample1, sample2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMean(self): with self.test_session(): p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32) dist = bernoulli.Bernoulli(probs=p) self.assertAllEqual(self.evaluate(dist.mean()), p) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVarianceAndStd(self): var = lambda p: p * (1. - p) with self.test_session(): @@ -310,7 +321,7 @@ class BernoulliTest(test.TestCase): [np.sqrt(var(0.5)), np.sqrt(var(0.4))]], dtype=np.float32)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBernoulliBernoulliKL(self): batch_size = 6 a_p = np.array([0.5] * batch_size, dtype=np.float32) diff --git a/tensorflow/python/kernel_tests/distributions/beta_test.py b/tensorflow/python/kernel_tests/distributions/beta_test.py index 4bc8303ebb6939f3f8e2637120b6510c225c2f12..36f3ffc333f74e3f6e672b6ba1591bf8de08a010 100644 --- a/tensorflow/python/kernel_tests/distributions/beta_test.py +++ b/tensorflow/python/kernel_tests/distributions/beta_test.py @@ -21,6 +21,7 @@ import importlib import numpy as np from tensorflow.python.client import session +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_shape @@ -282,6 +283,18 @@ class BetaTest(test.TestCase): self.assertAllClose( np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1) + def testBetaFullyReparameterized(self): + a = constant_op.constant(1.0) + b = constant_op.constant(2.0) + with backprop.GradientTape() as tape: + tape.watch(a) + tape.watch(b) + beta = beta_lib.Beta(a, b) + samples = beta.sample(100) + grad_a, grad_b = tape.gradient(samples, [a, b]) + self.assertIsNotNone(grad_a) + self.assertIsNotNone(grad_b) + # Test that sampling with the same seed twice gives the same results. def testBetaSampleMultipleTimes(self): with self.test_session(): diff --git a/tensorflow/python/kernel_tests/distributions/categorical_test.py b/tensorflow/python/kernel_tests/distributions/categorical_test.py index 68b4ffdb58c2fbcd308da8a25bbe2391c2ed90f6..d8939433ce68ffa561e8e2200826f88dbe283ac2 100644 --- a/tensorflow/python/kernel_tests/distributions/categorical_test.py +++ b/tensorflow/python/kernel_tests/distributions/categorical_test.py @@ -21,6 +21,7 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util @@ -376,6 +377,15 @@ class CategoricalTest(test.TestCase, parameterized.TestCase): self.assertAllClose( [0.4**2 + 0.6**2], [prob_val[:, :, :, 1].mean()], atol=1e-2) + def testNotReparameterized(self): + p = constant_op.constant([0.3, 0.3, 0.4]) + with backprop.GradientTape() as tape: + tape.watch(p) + dist = categorical.Categorical(p) + samples = dist.sample(100) + grad_p = tape.gradient(samples, p) + self.assertIsNone(grad_p) + def testLogPMFBroadcasting(self): with self.test_session(): # 1 x 2 x 2 diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py index 7922fb0606c6f4b475b25da716d5f9a169e213b5..1b9edcc85a7581de1cb1bd93fdbb9d47b8d1b84a 100644 --- a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py +++ b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py @@ -17,6 +17,9 @@ from __future__ import division from __future__ import print_function import numpy as np + +from tensorflow.python.eager import backprop +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -250,10 +253,10 @@ class DirichletMultinomialTest(test.TestCase): dist.variance(), dist.stddev(), ]) - self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.04) - self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.05) - self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.05) - self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.02) + self.assertAllClose(sample_mean_, analytic_mean, atol=0.04, rtol=0.) + self.assertAllClose(sample_cov_, analytic_cov, atol=0.05, rtol=0.) + self.assertAllClose(sample_var_, analytic_var, atol=0.05, rtol=0.) + self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.) def testCovariance(self): # Shape [2] @@ -442,7 +445,7 @@ class DirichletMultinomialTest(test.TestCase): dist.covariance(), ]) self.assertAllEqual([4, 3, 2], sample_mean.get_shape()) - self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.15) + self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.20) self.assertAllEqual([4, 3, 2, 2], sample_covariance.get_shape()) self.assertAllClose( actual_covariance_, sample_covariance_, atol=0., rtol=0.20) @@ -470,10 +473,25 @@ class DirichletMultinomialTest(test.TestCase): dist.covariance(), ]) self.assertAllEqual([4], sample_mean.get_shape()) - self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.05) + self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.20) self.assertAllEqual([4, 4], sample_covariance.get_shape()) self.assertAllClose( - actual_covariance_, sample_covariance_, atol=0., rtol=0.15) + actual_covariance_, sample_covariance_, atol=0., rtol=0.20) + + def testNotReparameterized(self): + total_count = constant_op.constant(5.0) + concentration = constant_op.constant([0.1, 0.1, 0.1]) + with backprop.GradientTape() as tape: + tape.watch(total_count) + tape.watch(concentration) + dist = ds.DirichletMultinomial( + total_count=total_count, + concentration=concentration) + samples = dist.sample(100) + grad_total_count, grad_concentration = tape.gradient( + samples, [total_count, concentration]) + self.assertIsNone(grad_total_count) + self.assertIsNone(grad_concentration) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py index bcec6ef610d0389f4b0f164ff4ab1a1cd1f6d1e5..67ed0447ede39d7f0738c8caf3cc665bcfe5fd0b 100644 --- a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py +++ b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py @@ -20,6 +20,7 @@ import importlib import numpy as np +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util @@ -190,10 +191,10 @@ class DirichletTest(test.TestCase): dist.stddev(), ]) - self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.04) - self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.06) - self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.03) - self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.02) + self.assertAllClose(sample_mean_, analytic_mean, atol=0.04, rtol=0.) + self.assertAllClose(sample_cov_, analytic_cov, atol=0.06, rtol=0.) + self.assertAllClose(sample_var_, analytic_var, atol=0.03, rtol=0.) + self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.) def testVariance(self): with self.test_session(): @@ -264,6 +265,15 @@ class DirichletTest(test.TestCase): a=1., b=2.).cdf)[0], 0.01) + def testDirichletFullyReparameterized(self): + alpha = constant_op.constant([1.0, 2.0, 3.0]) + with backprop.GradientTape() as tape: + tape.watch(alpha) + dirichlet = dirichlet_lib.Dirichlet(alpha) + samples = dirichlet.sample(100) + grad_alpha = tape.gradient(samples, alpha) + self.assertIsNotNone(grad_alpha) + def testDirichletDirichletKL(self): conc1 = np.array([[1., 2., 3., 1.5, 2.5, 3.5], [1.5, 2.5, 3.5, 4.5, 5.5, 6.5]]) diff --git a/tensorflow/python/kernel_tests/distributions/exponential_test.py b/tensorflow/python/kernel_tests/distributions/exponential_test.py index ebcd41b0e24ae8093752c84cf5077029f2ac9330..850da3e9697ab5f087761e9988094a3015636c36 100644 --- a/tensorflow/python/kernel_tests/distributions/exponential_test.py +++ b/tensorflow/python/kernel_tests/distributions/exponential_test.py @@ -23,6 +23,7 @@ import importlib import numpy as np from tensorflow.python.client import session +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util from tensorflow.python.ops import nn_ops @@ -163,6 +164,15 @@ class ExponentialTest(test.TestCase): stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01) + def testFullyReparameterized(self): + lam = constant_op.constant([0.1, 1.0]) + with backprop.GradientTape() as tape: + tape.watch(lam) + exponential = exponential_lib.Exponential(rate=lam) + samples = exponential.sample(100) + grad_lam = tape.gradient(samples, lam) + self.assertIsNotNone(grad_lam) + def testExponentialWithSoftplusRate(self): with self.test_session(): lam = [-2.2, -3.4] diff --git a/tensorflow/python/kernel_tests/distributions/gamma_test.py b/tensorflow/python/kernel_tests/distributions/gamma_test.py index 5e4813ac0762d2855d7fbe6754fe1466c29c06c9..297e20264c6d36f5b9098005393302337e3d1315 100644 --- a/tensorflow/python/kernel_tests/distributions/gamma_test.py +++ b/tensorflow/python/kernel_tests/distributions/gamma_test.py @@ -21,9 +21,10 @@ import importlib import numpy as np -from tensorflow.python.client import session +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import gamma as gamma_lib @@ -45,6 +46,7 @@ special = try_import("scipy.special") stats = try_import("scipy.stats") +@test_util.run_all_in_graph_and_eager_modes class GammaTest(test.TestCase): def testGammaShape(self): @@ -53,9 +55,9 @@ class GammaTest(test.TestCase): beta = constant_op.constant(11.0) gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) - self.assertEqual(gamma.batch_shape_tensor().eval(), (5,)) + self.assertEqual(self.evaluate(gamma.batch_shape_tensor()), (5,)) self.assertEqual(gamma.batch_shape, tensor_shape.TensorShape([5])) - self.assertAllEqual(gamma.event_shape_tensor().eval(), []) + self.assertAllEqual(self.evaluate(gamma.event_shape_tensor()), []) self.assertEqual(gamma.event_shape, tensor_shape.TensorShape([])) def testGammaLogPDF(self): @@ -74,8 +76,8 @@ class GammaTest(test.TestCase): if not stats: return expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) - self.assertAllClose(log_pdf.eval(), expected_log_pdf) - self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf)) + self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf) + self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf)) def testGammaLogPDFMultidimensional(self): with self.test_session(): @@ -87,10 +89,10 @@ class GammaTest(test.TestCase): x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) log_pdf = gamma.log_prob(x) - log_pdf_values = log_pdf.eval() + log_pdf_values = self.evaluate(log_pdf) self.assertEqual(log_pdf.get_shape(), (6, 2)) pdf = gamma.prob(x) - pdf_values = pdf.eval() + pdf_values = self.evaluate(pdf) self.assertEqual(pdf.get_shape(), (6, 2)) if not stats: return @@ -108,10 +110,10 @@ class GammaTest(test.TestCase): x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) log_pdf = gamma.log_prob(x) - log_pdf_values = log_pdf.eval() + log_pdf_values = self.evaluate(log_pdf) self.assertEqual(log_pdf.get_shape(), (6, 2)) pdf = gamma.prob(x) - pdf_values = pdf.eval() + pdf_values = self.evaluate(pdf) self.assertEqual(pdf.get_shape(), (6, 2)) if not stats: @@ -135,7 +137,7 @@ class GammaTest(test.TestCase): if not stats: return expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v) - self.assertAllClose(cdf.eval(), expected_cdf) + self.assertAllClose(self.evaluate(cdf), expected_cdf) def testGammaMean(self): with self.test_session(): @@ -146,7 +148,7 @@ class GammaTest(test.TestCase): if not stats: return expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v) - self.assertAllClose(gamma.mean().eval(), expected_means) + self.assertAllClose(self.evaluate(gamma.mean()), expected_means) def testGammaModeAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self): with self.test_session(): @@ -155,7 +157,7 @@ class GammaTest(test.TestCase): gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) expected_modes = (alpha_v - 1) / beta_v self.assertEqual(gamma.mode().get_shape(), (3,)) - self.assertAllClose(gamma.mode().eval(), expected_modes) + self.assertAllClose(self.evaluate(gamma.mode()), expected_modes) def testGammaModeAllowNanStatsFalseRaisesForUndefinedBatchMembers(self): with self.test_session(): @@ -166,7 +168,7 @@ class GammaTest(test.TestCase): rate=beta_v, allow_nan_stats=False) with self.assertRaisesOpError("x < y"): - gamma.mode().eval() + self.evaluate(gamma.mode()) def testGammaModeAllowNanStatsIsTrueReturnsNaNforUndefinedBatchMembers(self): with self.test_session(): @@ -179,7 +181,7 @@ class GammaTest(test.TestCase): expected_modes = (alpha_v - 1) / beta_v expected_modes[0] = np.nan self.assertEqual(gamma.mode().get_shape(), (3,)) - self.assertAllClose(gamma.mode().eval(), expected_modes) + self.assertAllClose(self.evaluate(gamma.mode()), expected_modes) def testGammaVariance(self): with self.test_session(): @@ -190,7 +192,7 @@ class GammaTest(test.TestCase): if not stats: return expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v) - self.assertAllClose(gamma.variance().eval(), expected_variances) + self.assertAllClose(self.evaluate(gamma.variance()), expected_variances) def testGammaStd(self): with self.test_session(): @@ -201,7 +203,7 @@ class GammaTest(test.TestCase): if not stats: return expected_stddev = stats.gamma.std(alpha_v, scale=1. / beta_v) - self.assertAllClose(gamma.stddev().eval(), expected_stddev) + self.assertAllClose(self.evaluate(gamma.stddev()), expected_stddev) def testGammaEntropy(self): with self.test_session(): @@ -212,10 +214,10 @@ class GammaTest(test.TestCase): if not stats: return expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v) - self.assertAllClose(gamma.entropy().eval(), expected_entropy) + self.assertAllClose(self.evaluate(gamma.entropy()), expected_entropy) def testGammaSampleSmallAlpha(self): - with session.Session(): + with self.test_session(): alpha_v = 0.05 beta_v = 1.0 alpha = constant_op.constant(alpha_v) @@ -223,7 +225,7 @@ class GammaTest(test.TestCase): n = 100000 gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) samples = gamma.sample(n, seed=137) - sample_values = samples.eval() + sample_values = self.evaluate(samples) self.assertEqual(samples.get_shape(), (n,)) self.assertEqual(sample_values.shape, (n,)) self.assertTrue(self._kstest(alpha_v, beta_v, sample_values)) @@ -240,7 +242,7 @@ class GammaTest(test.TestCase): atol=.15) def testGammaSample(self): - with session.Session(): + with self.test_session(): alpha_v = 4.0 beta_v = 3.0 alpha = constant_op.constant(alpha_v) @@ -248,7 +250,7 @@ class GammaTest(test.TestCase): n = 100000 gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) samples = gamma.sample(n, seed=137) - sample_values = samples.eval() + sample_values = self.evaluate(samples) self.assertEqual(samples.get_shape(), (n,)) self.assertEqual(sample_values.shape, (n,)) self.assertTrue(self._kstest(alpha_v, beta_v, sample_values)) @@ -264,14 +266,26 @@ class GammaTest(test.TestCase): stats.gamma.var(alpha_v, scale=1 / beta_v), atol=.15) + def testGammaFullyReparameterized(self): + alpha = constant_op.constant(4.0) + beta = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + tape.watch(alpha) + tape.watch(beta) + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) + samples = gamma.sample(100) + grad_alpha, grad_beta = tape.gradient(samples, [alpha, beta]) + self.assertIsNotNone(grad_alpha) + self.assertIsNotNone(grad_beta) + def testGammaSampleMultiDimensional(self): - with session.Session(): + with self.test_session(): alpha_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100 beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1 gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) n = 10000 samples = gamma.sample(n, seed=137) - sample_values = samples.eval() + sample_values = self.evaluate(samples) self.assertEqual(samples.get_shape(), (n, 10, 100)) self.assertEqual(sample_values.shape, (n, 10, 100)) zeros = np.zeros_like(alpha_v + beta_v) # 10 x 100 @@ -283,11 +297,11 @@ class GammaTest(test.TestCase): sample_values.mean(axis=0), stats.gamma.mean( alpha_bc, scale=1 / beta_bc), - rtol=.035) + atol=0., rtol=.05) self.assertAllClose( sample_values.var(axis=0), stats.gamma.var(alpha_bc, scale=1 / beta_bc), - atol=4.5) + atol=10.0, rtol=0.) fails = 0 trials = 0 for ai, a in enumerate(np.reshape(alpha_v, [-1])): @@ -306,12 +320,12 @@ class GammaTest(test.TestCase): return ks < 0.02 def testGammaPdfOfSampleMultiDims(self): - with session.Session() as sess: + with self.test_session(): gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]]) num = 50000 samples = gamma.sample(num, seed=137) pdfs = gamma.prob(samples) - sample_vals, pdf_vals = sess.run([samples, pdfs]) + sample_vals, pdf_vals = self.evaluate([samples, pdfs]) self.assertEqual(samples.get_shape(), (num, 2, 2)) self.assertEqual(pdfs.get_shape(), (num, 2, 2)) self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02) @@ -345,18 +359,18 @@ class GammaTest(test.TestCase): with self.test_session(): alpha_v = constant_op.constant(0.0, name="alpha") beta_v = constant_op.constant(1.0, name="beta") - gamma = gamma_lib.Gamma(concentration=alpha_v, - rate=beta_v, - validate_args=True) - with self.assertRaisesOpError("alpha"): - gamma.mean().eval() + with self.assertRaisesOpError("x > 0"): + gamma = gamma_lib.Gamma(concentration=alpha_v, + rate=beta_v, + validate_args=True) + self.evaluate(gamma.mean()) alpha_v = constant_op.constant(1.0, name="alpha") beta_v = constant_op.constant(0.0, name="beta") - gamma = gamma_lib.Gamma(concentration=alpha_v, - rate=beta_v, - validate_args=True) - with self.assertRaisesOpError("beta"): - gamma.mean().eval() + with self.assertRaisesOpError("x > 0"): + gamma = gamma_lib.Gamma(concentration=alpha_v, + rate=beta_v, + validate_args=True) + self.evaluate(gamma.mean()) def testGammaWithSoftplusConcentrationRate(self): with self.test_session(): @@ -364,10 +378,10 @@ class GammaTest(test.TestCase): beta_v = constant_op.constant([1.0, -3.6], name="beta") gamma = gamma_lib.GammaWithSoftplusConcentrationRate( concentration=alpha_v, rate=beta_v) - self.assertAllEqual(nn_ops.softplus(alpha_v).eval(), - gamma.concentration.eval()) - self.assertAllEqual(nn_ops.softplus(beta_v).eval(), - gamma.rate.eval()) + self.assertAllEqual(self.evaluate(nn_ops.softplus(alpha_v)), + self.evaluate(gamma.concentration)) + self.assertAllEqual(self.evaluate(nn_ops.softplus(beta_v)), + self.evaluate(gamma.rate)) def testGammaGammaKL(self): alpha0 = np.array([3.]) @@ -377,15 +391,15 @@ class GammaTest(test.TestCase): beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.]) # Build graph. - with self.test_session() as sess: + with self.test_session(): g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0) g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1) x = g0.sample(int(1e4), seed=0) kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0) kl_actual = kullback_leibler.kl_divergence(g0, g1) - # Execute graph. - [kl_sample_, kl_actual_] = sess.run([kl_sample, kl_actual]) + # Execute graph. + [kl_sample_, kl_actual_] = self.evaluate([kl_sample, kl_actual]) self.assertEqual(beta0.shape, kl_actual.get_shape()) @@ -399,7 +413,7 @@ class GammaTest(test.TestCase): + alpha0 * (beta1 / beta0 - 1.)) self.assertAllClose(kl_expected, kl_actual_, atol=0., rtol=1e-6) - self.assertAllClose(kl_sample_, kl_actual_, atol=0., rtol=1e-2) + self.assertAllClose(kl_sample_, kl_actual_, atol=0., rtol=1e-1) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/distributions/laplace_test.py b/tensorflow/python/kernel_tests/distributions/laplace_test.py index 918c7f63f2065525338632ba68cb180c7c50dea6..24b243f647e495c47d57f914951263e3ee4ca7a5 100644 --- a/tensorflow/python/kernel_tests/distributions/laplace_test.py +++ b/tensorflow/python/kernel_tests/distributions/laplace_test.py @@ -22,6 +22,7 @@ import importlib import numpy as np from tensorflow.python.client import session +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util @@ -255,6 +256,18 @@ class LaplaceTest(test.TestCase): atol=0.) self.assertTrue(self._kstest(loc_v, scale_v, sample_values)) + def testLaplaceFullyReparameterized(self): + loc = constant_op.constant(4.0) + scale = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + tape.watch(loc) + tape.watch(scale) + laplace = laplace_lib.Laplace(loc=loc, scale=scale) + samples = laplace.sample(100) + grad_loc, grad_scale = tape.gradient(samples, [loc, scale]) + self.assertIsNotNone(grad_loc) + self.assertIsNotNone(grad_scale) + def testLaplaceSampleMultiDimensional(self): with session.Session(): loc_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100 diff --git a/tensorflow/python/kernel_tests/distributions/multinomial_test.py b/tensorflow/python/kernel_tests/distributions/multinomial_test.py index e24e8ade73a7ad762c877214f5ec3ee0848863fe..bfd40ba2b7a5d32e957507b36d44e1198bd3867f 100644 --- a/tensorflow/python/kernel_tests/distributions/multinomial_test.py +++ b/tensorflow/python/kernel_tests/distributions/multinomial_test.py @@ -18,6 +18,8 @@ from __future__ import print_function import numpy as np +from tensorflow.python.eager import backprop +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -310,10 +312,10 @@ class MultinomialTest(test.TestCase): dist.covariance(), ]) self.assertAllEqual([4, 3, 2], sample_mean.get_shape()) - self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.07) + self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.10) self.assertAllEqual([4, 3, 2, 2], sample_covariance.get_shape()) self.assertAllClose( - actual_covariance_, sample_covariance_, atol=0., rtol=0.10) + actual_covariance_, sample_covariance_, atol=0., rtol=0.20) def testSampleUnbiasedScalarBatch(self): with self.test_session() as sess: @@ -338,10 +340,24 @@ class MultinomialTest(test.TestCase): dist.covariance(), ]) self.assertAllEqual([4], sample_mean.get_shape()) - self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.07) + self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.10) self.assertAllEqual([4, 4], sample_covariance.get_shape()) self.assertAllClose( - actual_covariance_, sample_covariance_, atol=0., rtol=0.10) + actual_covariance_, sample_covariance_, atol=0., rtol=0.20) + + def testNotReparameterized(self): + total_count = constant_op.constant(5.0) + p = constant_op.constant([0.2, 0.6]) + with backprop.GradientTape() as tape: + tape.watch(total_count) + tape.watch(p) + dist = multinomial.Multinomial( + total_count=total_count, + probs=p) + samples = dist.sample(100) + grad_total_count, grad_p = tape.gradient(samples, [total_count, p]) + self.assertIsNone(grad_total_count) + self.assertIsNone(grad_p) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/distributions/normal_test.py b/tensorflow/python/kernel_tests/distributions/normal_test.py index d793e03272909cc97543e313041b6ae7f487ae3f..7ff48c0c10f4d2cd18072a22cdcef0fefc530eae 100644 --- a/tensorflow/python/kernel_tests/distributions/normal_test.py +++ b/tensorflow/python/kernel_tests/distributions/normal_test.py @@ -23,6 +23,7 @@ import math import numpy as np +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -77,20 +78,20 @@ class NormalTest(test.TestCase): self.assertEqual(expected, mu_shape) self.assertEqual(expected, sigma_shape) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testParamShapes(self): sample_shape = [10, 3, 4] self._testParamShapes(sample_shape, sample_shape) self._testParamShapes(constant_op.constant(sample_shape), sample_shape) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testParamStaticShapes(self): sample_shape = [10, 3, 4] self._testParamStaticShapes(sample_shape, sample_shape) self._testParamStaticShapes( tensor_shape.TensorShape(sample_shape), sample_shape) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalWithSoftplusScale(self): with self.test_session(): mu = array_ops.zeros((10, 3)) @@ -100,7 +101,7 @@ class NormalTest(test.TestCase): self.assertAllEqual( self.evaluate(nn_ops.softplus(rho)), self.evaluate(normal.scale)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalLogPDF(self): with self.test_session(): batch_size = 6 @@ -134,7 +135,7 @@ class NormalTest(test.TestCase): self.assertAllClose(expected_log_pdf, self.evaluate(log_pdf)) self.assertAllClose(np.exp(expected_log_pdf), self.evaluate(pdf)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalLogPDFMultidimensional(self): with self.test_session(): batch_size = 6 @@ -172,7 +173,7 @@ class NormalTest(test.TestCase): self.assertAllClose(expected_log_pdf, log_pdf_values) self.assertAllClose(np.exp(expected_log_pdf), pdf_values) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalCDF(self): with self.test_session(): batch_size = 50 @@ -194,7 +195,7 @@ class NormalTest(test.TestCase): expected_cdf = stats.norm(mu, sigma).cdf(x) self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalSurvivalFunction(self): with self.test_session(): batch_size = 50 @@ -217,7 +218,7 @@ class NormalTest(test.TestCase): expected_sf = stats.norm(mu, sigma).sf(x) self.assertAllClose(expected_sf, self.evaluate(sf), atol=0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalLogCDF(self): with self.test_session(): batch_size = 50 @@ -239,7 +240,7 @@ class NormalTest(test.TestCase): if not stats: return expected_cdf = stats.norm(mu, sigma).logcdf(x) - self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-5) + self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-3) def testFiniteGradientAtDifficultPoints(self): for dtype in [np.float32, np.float64]: @@ -261,7 +262,7 @@ class NormalTest(test.TestCase): self.assertAllFinite(grads[0]) self.assertAllFinite(grads[1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalLogSurvivalFunction(self): with self.test_session(): batch_size = 50 @@ -285,7 +286,7 @@ class NormalTest(test.TestCase): expected_sf = stats.norm(mu, sigma).logsf(x) self.assertAllClose(expected_sf, self.evaluate(sf), atol=0, rtol=1e-5) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalEntropyWithScalarInputs(self): # Scipy.stats.norm cannot deal with the shapes in the other test. with self.test_session(): @@ -307,7 +308,7 @@ class NormalTest(test.TestCase): expected_entropy = stats.norm(mu_v, sigma_v).entropy() self.assertAllClose(expected_entropy, self.evaluate(entropy)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalEntropy(self): with self.test_session(): mu_v = np.array([1.0, 1.0, 1.0]) @@ -328,7 +329,7 @@ class NormalTest(test.TestCase): self.assertAllEqual(normal.batch_shape, entropy.get_shape()) self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalMeanAndMode(self): with self.test_session(): # Mu will be broadcast to [7, 7, 7]. @@ -343,7 +344,7 @@ class NormalTest(test.TestCase): self.assertAllEqual((3,), normal.mode().get_shape()) self.assertAllEqual([7., 7, 7], self.evaluate(normal.mode())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalQuantile(self): with self.test_session(): batch_size = 52 @@ -395,7 +396,7 @@ class NormalTest(test.TestCase): def testQuantileFiniteGradientAtDifficultPointsFloat64(self): self._baseQuantileFiniteGradientAtDifficultPoints(np.float64) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalVariance(self): with self.test_session(): # sigma will be broadcast to [7, 7, 7] @@ -407,7 +408,7 @@ class NormalTest(test.TestCase): self.assertAllEqual((3,), normal.variance().get_shape()) self.assertAllEqual([49., 49, 49], self.evaluate(normal.variance())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalStandardDeviation(self): with self.test_session(): # sigma will be broadcast to [7, 7, 7] @@ -419,7 +420,7 @@ class NormalTest(test.TestCase): self.assertAllEqual((3,), normal.stddev().get_shape()) self.assertAllEqual([7., 7, 7], self.evaluate(normal.stddev())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalSample(self): with self.test_session(): mu = constant_op.constant(3.0) @@ -453,7 +454,19 @@ class NormalTest(test.TestCase): self.assertAllEqual(expected_samples_shape, samples.get_shape()) self.assertAllEqual(expected_samples_shape, sample_values.shape) - @test_util.run_in_graph_and_eager_modes() + def testNormalFullyReparameterized(self): + mu = constant_op.constant(4.0) + sigma = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + tape.watch(mu) + tape.watch(sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) + samples = normal.sample(100) + grad_mu, grad_sigma = tape.gradient(samples, [mu, sigma]) + self.assertIsNotNone(grad_mu) + self.assertIsNotNone(grad_sigma) + + @test_util.run_in_graph_and_eager_modes def testNormalSampleMultiDimensional(self): with self.test_session(): batch_size = 2 @@ -489,7 +502,7 @@ class NormalTest(test.TestCase): self.assertAllEqual(expected_samples_shape, samples.get_shape()) self.assertAllEqual(expected_samples_shape, sample_values.shape) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNegativeSigmaFails(self): with self.test_session(): with self.assertRaisesOpError("Condition x > 0 did not hold"): @@ -497,7 +510,7 @@ class NormalTest(test.TestCase): loc=[1.], scale=[-5.], validate_args=True, name="G") self.evaluate(normal.mean()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalShape(self): with self.test_session(): mu = constant_op.constant([-3.0] * 5) @@ -524,7 +537,7 @@ class NormalTest(test.TestCase): feed_dict={mu: 5.0, sigma: [1.0, 2.0]}), [2]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalNormalKL(self): batch_size = 6 mu_a = np.array([3.0] * batch_size) diff --git a/tensorflow/python/kernel_tests/distributions/special_math_test.py b/tensorflow/python/kernel_tests/distributions/special_math_test.py index 4565bf5c4669b4d416049816046f6f8ed187270d..a634194ce5293f4d7e7a68aa661080ed06493297 100644 --- a/tensorflow/python/kernel_tests/distributions/special_math_test.py +++ b/tensorflow/python/kernel_tests/distributions/special_math_test.py @@ -89,7 +89,7 @@ class NdtriTest(test.TestCase): all_true = np.ones_like(is_finite, dtype=np.bool) self.assertAllEqual(all_true, is_finite) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNdtri(self): """Verifies that ndtri computation is correct.""" with self.test_session(): @@ -138,11 +138,11 @@ class NdtriTest(test.TestCase): lambda x: special_math.ndtri(x), p) # pylint: disable=unnecessary-lambda self.assertAllFinite(self.evaluate(grads[0])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNdtriFiniteGradientFloat32(self): self._baseNdtriFiniteGradientTest(np.float32) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNdtriFiniteGradientFloat64(self): self._baseNdtriFiniteGradientTest(np.float64) diff --git a/tensorflow/python/kernel_tests/distributions/student_t_test.py b/tensorflow/python/kernel_tests/distributions/student_t_test.py index a4fdb658e857d832d5bf69485bbfb2517646a7b7..05590542efe2623e608f783233db68240331ba20 100644 --- a/tensorflow/python/kernel_tests/distributions/student_t_test.py +++ b/tensorflow/python/kernel_tests/distributions/student_t_test.py @@ -23,6 +23,7 @@ import math import numpy as np +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import random_seed from tensorflow.python.framework import test_util @@ -172,11 +173,11 @@ class StudentTTest(test.TestCase): sample_values = self.evaluate(samples) n_val = 200000 self.assertEqual(sample_values.shape, (n_val,)) - self.assertAllClose(sample_values.mean(), mu_v, rtol=1e-2, atol=0) + self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0) self.assertAllClose( sample_values.var(), sigma_v**2 * df_v / (df_v - 2), - rtol=1e-2, + rtol=0.1, atol=0) self._checkKLApprox(df_v, mu_v, sigma_v, sample_values) @@ -215,11 +216,11 @@ class StudentTTest(test.TestCase): def testStudentSampleMultiDimensional(self): with self.test_session(): batch_size = 7 - df = constant_op.constant([[3., 7.]] * batch_size) + df = constant_op.constant([[5., 7.]] * batch_size) mu = constant_op.constant([[3., -3.]] * batch_size) sigma = constant_op.constant([[math.sqrt(10.), math.sqrt(15.)]] * batch_size) - df_v = [3., 7.] + df_v = [5., 7.] mu_v = [3., -3.] sigma_v = [np.sqrt(10.), np.sqrt(15.)] n = constant_op.constant(200000) @@ -228,21 +229,21 @@ class StudentTTest(test.TestCase): sample_values = self.evaluate(samples) self.assertEqual(samples.get_shape(), (200000, batch_size, 2)) self.assertAllClose( - sample_values[:, 0, 0].mean(), mu_v[0], rtol=1e-2, atol=0) + sample_values[:, 0, 0].mean(), mu_v[0], rtol=0.1, atol=0) self.assertAllClose( sample_values[:, 0, 0].var(), sigma_v[0]**2 * df_v[0] / (df_v[0] - 2), - rtol=1e-1, + rtol=0.2, atol=0) self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0]) self.assertAllClose( - sample_values[:, 0, 1].mean(), mu_v[1], rtol=1e-2, atol=0) + sample_values[:, 0, 1].mean(), mu_v[1], rtol=0.1, atol=0) self.assertAllClose( sample_values[:, 0, 1].var(), sigma_v[1]**2 * df_v[1] / (df_v[1] - 2), - rtol=1e-1, + rtol=0.2, atol=0) - self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 1]) + self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1]) def _checkKLApprox(self, df, mu, sigma, samples): n = samples.size @@ -272,7 +273,7 @@ class StudentTTest(test.TestCase): self.assertEqual(student.entropy().get_shape(), (3,)) self.assertEqual(student.log_prob(2.).get_shape(), (3,)) self.assertEqual(student.prob(2.).get_shape(), (3,)) - self.assertEqual(student.sample(37, seed=123456).get_shape(), (37, 3,)) + self.assertEqual(student.sample(37).get_shape(), (37, 3,)) _check(student_t.StudentT(df=[2., 3., 4.,], loc=2., scale=1.)) _check(student_t.StudentT(df=7., loc=[2., 3., 4.,], scale=1.)) @@ -445,15 +446,30 @@ class StudentTTest(test.TestCase): self.assertEqual(samples.get_shape(), (num,)) self.assertEqual(pdfs.get_shape(), (num,)) self.assertEqual(mean.get_shape(), ()) - self.assertNear(np.pi, np.mean(sample_vals), err=0.02) + self.assertNear(np.pi, np.mean(sample_vals), err=0.1) self.assertNear(np.pi, mean_val, err=1e-6) # Verify integral over sample*pdf ~= 1. # Tolerance increased since eager was getting a value of 1.002041. - self._assertIntegral(sample_vals, pdf_vals, err=3e-3) + self._assertIntegral(sample_vals, pdf_vals, err=5e-2) if not stats: return self.assertNear(stats.t.pdf(np.pi, 3., loc=np.pi), mean_pdf_val, err=1e-6) + def testFullyReparameterized(self): + df = constant_op.constant(2.0) + mu = constant_op.constant(1.0) + sigma = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + tape.watch(df) + tape.watch(mu) + tape.watch(sigma) + student = student_t.StudentT(df=df, loc=mu, scale=sigma) + samples = student.sample(100) + grad_df, grad_mu, grad_sigma = tape.gradient(samples, [df, mu, sigma]) + self.assertIsNotNone(grad_df) + self.assertIsNotNone(grad_mu) + self.assertIsNotNone(grad_sigma) + def testPdfOfSampleMultiDims(self): student = student_t.StudentT(df=[7., 11.], loc=[[5.], [6.]], scale=3.) self.assertAllEqual([], student.event_shape) @@ -466,22 +482,22 @@ class StudentTTest(test.TestCase): sample_vals, pdf_vals = self.evaluate([samples, pdfs]) self.assertEqual(samples.get_shape(), (num, 2, 2)) self.assertEqual(pdfs.get_shape(), (num, 2, 2)) - self.assertNear(5., np.mean(sample_vals[:, 0, :]), err=.03) - self.assertNear(6., np.mean(sample_vals[:, 1, :]), err=.03) - self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02) - self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02) - self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02) - self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02) + self.assertNear(5., np.mean(sample_vals[:, 0, :]), err=0.1) + self.assertNear(6., np.mean(sample_vals[:, 1, :]), err=0.1) + self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.05) + self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.05) + self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.05) + self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.05) if not stats: return self.assertNear( stats.t.var(7., loc=0., scale=3.), # loc d.n. effect var np.var(sample_vals[:, :, 0]), - err=.4) + err=1.0) self.assertNear( stats.t.var(11., loc=0., scale=3.), # loc d.n. effect var np.var(sample_vals[:, :, 1]), - err=.4) + err=1.0) def _assertIntegral(self, sample_vals, pdf_vals, err=1.5e-3): s_p = zip(sample_vals, pdf_vals) diff --git a/tensorflow/python/kernel_tests/distributions/uniform_test.py b/tensorflow/python/kernel_tests/distributions/uniform_test.py index e74051c9013b7d51914868e66022546ae8862b60..bc9c267b9a5eac6fd8c9c4290dcc4b56865ddb50 100644 --- a/tensorflow/python/kernel_tests/distributions/uniform_test.py +++ b/tensorflow/python/kernel_tests/distributions/uniform_test.py @@ -22,6 +22,7 @@ import importlib import numpy as np +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors from tensorflow.python.framework import tensor_shape @@ -47,7 +48,7 @@ stats = try_import("scipy.stats") class UniformTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformRange(self): with self.test_session(): a = 3.0 @@ -57,7 +58,7 @@ class UniformTest(test.TestCase): self.assertAllClose(b, self.evaluate(uniform.high)) self.assertAllClose(b - a, self.evaluate(uniform.range())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformPDF(self): with self.test_session(): a = constant_op.constant([-3.0] * 5 + [15.0]) @@ -83,7 +84,7 @@ class UniformTest(test.TestCase): log_pdf = uniform.log_prob(x) self.assertAllClose(np.log(expected_pdf), self.evaluate(log_pdf)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformShape(self): with self.test_session(): a = constant_op.constant([-3.0] * 5) @@ -95,7 +96,7 @@ class UniformTest(test.TestCase): self.assertAllEqual(self.evaluate(uniform.event_shape_tensor()), []) self.assertEqual(uniform.event_shape, tensor_shape.TensorShape([])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformPDFWithScalarEndpoint(self): with self.test_session(): a = constant_op.constant([0.0, 5.0]) @@ -108,7 +109,7 @@ class UniformTest(test.TestCase): pdf = uniform.prob(x) self.assertAllClose(expected_pdf, self.evaluate(pdf)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformCDF(self): with self.test_session(): batch_size = 6 @@ -132,7 +133,7 @@ class UniformTest(test.TestCase): log_cdf = uniform.log_cdf(x) self.assertAllClose(np.log(_expected_cdf()), self.evaluate(log_cdf)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformEntropy(self): with self.test_session(): a_v = np.array([1.0, 1.0, 1.0]) @@ -142,7 +143,7 @@ class UniformTest(test.TestCase): expected_entropy = np.log(b_v - a_v) self.assertAllClose(expected_entropy, self.evaluate(uniform.entropy())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformAssertMaxGtMin(self): with self.test_session(): a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32) @@ -153,7 +154,7 @@ class UniformTest(test.TestCase): uniform = uniform_lib.Uniform(low=a_v, high=b_v, validate_args=True) self.evaluate(uniform.low) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformSample(self): with self.test_session(): a = constant_op.constant([3.0, 4.0]) @@ -168,15 +169,15 @@ class UniformTest(test.TestCase): sample_values = self.evaluate(samples) self.assertEqual(sample_values.shape, (100000, 2)) self.assertAllClose( - sample_values[::, 0].mean(), (b_v + a1_v) / 2, atol=1e-2) + sample_values[::, 0].mean(), (b_v + a1_v) / 2, atol=1e-1, rtol=0.) self.assertAllClose( - sample_values[::, 1].mean(), (b_v + a2_v) / 2, atol=1e-2) + sample_values[::, 1].mean(), (b_v + a2_v) / 2, atol=1e-1, rtol=0.) self.assertFalse( np.any(sample_values[::, 0] < a1_v) or np.any(sample_values >= b_v)) self.assertFalse( np.any(sample_values[::, 1] < a2_v) or np.any(sample_values >= b_v)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def _testUniformSampleMultiDimensional(self): # DISABLED: Please enable this test once b/issues/30149644 is resolved. with self.test_session(): @@ -207,7 +208,7 @@ class UniformTest(test.TestCase): self.assertAllClose( sample_values[:, 0, 1].mean(), (a_v[1] + b_v[1]) / 2, atol=1e-2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformMean(self): with self.test_session(): a = 10.0 @@ -218,7 +219,7 @@ class UniformTest(test.TestCase): s_uniform = stats.uniform(loc=a, scale=b - a) self.assertAllClose(self.evaluate(uniform.mean()), s_uniform.mean()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformVariance(self): with self.test_session(): a = 10.0 @@ -229,7 +230,7 @@ class UniformTest(test.TestCase): s_uniform = stats.uniform(loc=a, scale=b - a) self.assertAllClose(self.evaluate(uniform.variance()), s_uniform.var()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformStd(self): with self.test_session(): a = 10.0 @@ -240,7 +241,7 @@ class UniformTest(test.TestCase): s_uniform = stats.uniform(loc=a, scale=b - a) self.assertAllClose(self.evaluate(uniform.stddev()), s_uniform.std()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformNans(self): with self.test_session(): a = 10.0 @@ -258,7 +259,7 @@ class UniformTest(test.TestCase): self.assertFalse(is_nan[0]) self.assertTrue(is_nan[1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformSamplePdf(self): with self.test_session(): a = 10.0 @@ -268,7 +269,7 @@ class UniformTest(test.TestCase): self.evaluate( math_ops.reduce_all(uniform.prob(uniform.sample(10)) > 0))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformBroadcasting(self): with self.test_session(): a = 10.0 @@ -279,7 +280,7 @@ class UniformTest(test.TestCase): expected_pdf = np.array([[1.0, 0.1], [0.0, 0.1], [1.0, 0.0]]) self.assertAllClose(expected_pdf, self.evaluate(pdf)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformSampleWithShape(self): with self.test_session(): a = 10.0 @@ -299,6 +300,18 @@ class UniformTest(test.TestCase): expected_pdf = [1.0, 0.1] self.assertAllClose(expected_pdf, self.evaluate(pdf)) + def testFullyReparameterized(self): + a = constant_op.constant(0.1) + b = constant_op.constant(0.8) + with backprop.GradientTape() as tape: + tape.watch(a) + tape.watch(b) + uniform = uniform_lib.Uniform(a, b) + samples = uniform.sample(100) + grad_a, grad_b = tape.gradient(samples, [a, b]) + self.assertIsNotNone(grad_a) + self.assertIsNotNone(grad_b) + # Eager doesn't pass due to a type mismatch in one of the ops. def testUniformFloat64(self): uniform = uniform_lib.Uniform( diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py index 2f256d3e8beac145a14ca1dd63f267fb5f4ef3a5..9d38ffcb4a963efb71153f59d6269ba84a5d1379 100644 --- a/tensorflow/python/kernel_tests/distributions/util_test.py +++ b/tensorflow/python/kernel_tests/distributions/util_test.py @@ -59,65 +59,6 @@ def _logit(x): class AssertCloseTest(test.TestCase): - def testAssertCloseIntegerDtype(self): - x = array_ops.placeholder(dtypes.int32) - y = x - z = array_ops.placeholder(dtypes.int32) - feed_dict = {x: [1, 5, 10, 15, 20], z: [2, 5, 10, 15, 20]} - with self.test_session(): - with ops.control_dependencies([du.assert_close(x, y)]): - array_ops.identity(x).eval(feed_dict=feed_dict) - - with ops.control_dependencies([du.assert_close(y, x)]): - array_ops.identity(x).eval(feed_dict=feed_dict) - - with self.assertRaisesOpError("Condition x ~= y"): - with ops.control_dependencies([du.assert_close(x, z)]): - array_ops.identity(x).eval(feed_dict=feed_dict) - - with self.assertRaisesOpError("Condition x ~= y"): - with ops.control_dependencies([du.assert_close(y, z)]): - array_ops.identity(y).eval(feed_dict=feed_dict) - - def testAssertCloseNonIntegerDtype(self): - x = array_ops.placeholder(dtypes.float32) - y = x + 1e-8 - z = array_ops.placeholder(dtypes.float32) - feed_dict = {x: [1., 5, 10, 15, 20], z: [2., 5, 10, 15, 20]} - with self.test_session(): - with ops.control_dependencies([du.assert_close(x, y)]): - array_ops.identity(x).eval(feed_dict=feed_dict) - - with ops.control_dependencies([du.assert_close(y, x)]): - array_ops.identity(x).eval(feed_dict=feed_dict) - - with self.assertRaisesOpError("Condition x ~= y"): - with ops.control_dependencies([du.assert_close(x, z)]): - array_ops.identity(x).eval(feed_dict=feed_dict) - - with self.assertRaisesOpError("Condition x ~= y"): - with ops.control_dependencies([du.assert_close(y, z)]): - array_ops.identity(y).eval(feed_dict=feed_dict) - - @test_util.run_in_graph_and_eager_modes() - def testAssertCloseEpsilon(self): - x = [0., 5, 10, 15, 20] - # x != y - y = [0.1, 5, 10, 15, 20] - # x = z - z = [1e-8, 5, 10, 15, 20] - with self.test_session(): - with ops.control_dependencies([du.assert_close(x, z)]): - self.evaluate(array_ops.identity(x)) - - with self.assertRaisesOpError("Condition x ~= y"): - with ops.control_dependencies([du.assert_close(x, y)]): - self.evaluate(array_ops.identity(x)) - - with self.assertRaisesOpError("Condition x ~= y"): - with ops.control_dependencies([du.assert_close(y, z)]): - self.evaluate(array_ops.identity(y)) - def testAssertIntegerForm(self): # This should only be detected as an integer. x = array_ops.placeholder(dtypes.float32) @@ -150,21 +91,21 @@ class AssertCloseTest(test.TestCase): class MaybeGetStaticTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetStaticInt(self): x = 2 self.assertEqual(x, du.maybe_get_static_value(x)) self.assertAllClose( np.array(2.), du.maybe_get_static_value(x, dtype=np.float64)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetStaticNumpyArray(self): x = np.array(2, dtype=np.int32) self.assertEqual(x, du.maybe_get_static_value(x)) self.assertAllClose( np.array(2.), du.maybe_get_static_value(x, dtype=np.float64)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetStaticConstant(self): x = constant_op.constant(2, dtype=dtypes.int32) self.assertEqual(np.array(2, dtype=np.int32), du.maybe_get_static_value(x)) @@ -179,7 +120,7 @@ class MaybeGetStaticTest(test.TestCase): class GetLogitsAndProbsTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testImproperArguments(self): with self.test_session(): with self.assertRaises(ValueError): @@ -188,7 +129,7 @@ class GetLogitsAndProbsTest(test.TestCase): with self.assertRaises(ValueError): du.get_logits_and_probs(logits=[0.1], probs=[0.1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLogits(self): p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32) logits = _logit(p) @@ -200,7 +141,7 @@ class GetLogitsAndProbsTest(test.TestCase): self.assertAllClose(p, self.evaluate(new_p), rtol=1e-5, atol=0.) self.assertAllClose(logits, self.evaluate(new_logits), rtol=1e-5, atol=0.) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLogitsMultidimensional(self): p = np.array([0.2, 0.3, 0.5], dtype=np.float32) logits = np.log(p) @@ -212,7 +153,7 @@ class GetLogitsAndProbsTest(test.TestCase): self.assertAllClose(self.evaluate(new_p), p) self.assertAllClose(self.evaluate(new_logits), logits) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testProbability(self): p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32) @@ -223,7 +164,7 @@ class GetLogitsAndProbsTest(test.TestCase): self.assertAllClose(_logit(p), self.evaluate(new_logits)) self.assertAllClose(p, self.evaluate(new_p)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testProbabilityMultidimensional(self): p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32) @@ -234,7 +175,7 @@ class GetLogitsAndProbsTest(test.TestCase): self.assertAllClose(np.log(p), self.evaluate(new_logits)) self.assertAllClose(p, self.evaluate(new_p)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testProbabilityValidateArgs(self): p = [0.01, 0.2, 0.5, 0.7, .99] # Component less than 0. @@ -265,7 +206,7 @@ class GetLogitsAndProbsTest(test.TestCase): probs=p3, validate_args=False) self.evaluate(prob) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testProbabilityValidateArgsMultidimensional(self): p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32) # Component less than 0. Still sums to 1. @@ -367,7 +308,7 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase): param) checked_param.eval(feed_dict={param: np.ones([int(2**11+1)])}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUnsupportedDtype(self): with self.test_session(): with self.assertRaises(TypeError): @@ -552,7 +493,7 @@ class RotateTransposeTest(test.TestCase): x = np.array(x) return np.transpose(x, np.roll(np.arange(len(x.shape)), shift)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRollStatic(self): with self.test_session(): if context.executing_eagerly(): diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py index 14a336c6881960cbf03ba767a051835feebf9d04..9e7b5283381dd7bc0725e1ab6fb9d7d13153f02d 100644 --- a/tensorflow/python/kernel_tests/fifo_queue_test.py +++ b/tensorflow/python/kernel_tests/fifo_queue_test.py @@ -126,14 +126,14 @@ class FIFOQueueTest(test.TestCase): q.enqueue_many([[1, 2, 3, 4], [[1, 1], [2, 2], [3, 3], [4, 4]]]).run() self.assertEqual(4, q.size().eval()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMultipleDequeues(self): q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) self.evaluate(q.enqueue_many([[1, 2, 3]])) a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()]) self.assertAllEqual(set([1, 2, 3]), set([a, b, c])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testQueuesDontShare(self): q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) self.evaluate(q.enqueue(1)) diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py index facadc971ff516e4f9edea0c4f52ab0953ec5fce..bfd4a8fd49c22950cc2d0f0117ca635fbdcb6caa 100644 --- a/tensorflow/python/kernel_tests/functional_ops_test.py +++ b/tensorflow/python/kernel_tests/functional_ops_test.py @@ -56,7 +56,7 @@ def simple_scoped_fn(a, x): class FunctionalOpsTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFoldl_Simple(self): with self.test_session(): elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") @@ -72,7 +72,7 @@ class FunctionalOpsTest(test.TestCase): initializer=10) self.assertAllEqual(880, self.evaluate(r)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFoldl_SingleInputMultiOutput(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -83,7 +83,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(22, r_value[0]) self.assertAllEqual(20, r_value[1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFoldl_MultiInputSingleOutput(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -111,7 +111,7 @@ class FunctionalOpsTest(test.TestCase): self.assertEqual(len(variables.trainable_variables()), 1) self.assertAllEqual(880, self.evaluate(r)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFoldr_Simple(self): with self.test_session(): elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") @@ -127,7 +127,7 @@ class FunctionalOpsTest(test.TestCase): initializer=10) self.assertAllEqual(1282, self.evaluate(r)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFoldr_SingleInputMultiOutput(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -138,7 +138,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(22, r_value[0]) self.assertAllEqual(20, r_value[1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFoldr_MultiInputSingleOutput(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -182,7 +182,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(720.0, self.evaluate(r)) # pylint: enable=unnecessary-lambda - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMap_Simple(self): with self.test_session(): nums = [1, 2, 3, 4, 5, 6] @@ -202,7 +202,7 @@ class FunctionalOpsTest(test.TestCase): values=constant_op.constant([0, 1, 2]), dense_shape=[2, 2])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMapOverScalarErrors(self): with self.assertRaisesRegexp(ValueError, "not scalars"): functional_ops.map_fn(lambda x: x, [1, 2]) @@ -251,7 +251,7 @@ class FunctionalOpsTest(test.TestCase): r = gradients_impl.gradients(y, elems)[0] self.assertAllEqual([4.0, 8.0, 12.0, 16.0, 20.0, 24.0], self.evaluate(r)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMap_SimpleNotTensor(self): with self.test_session(): nums = np.array([1, 2, 3, 4, 5, 6]) @@ -260,7 +260,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual( np.array([(x + 3) * 2 for x in nums]), self.evaluate(r)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMap_SingleInputMultiOutput(self): with self.test_session(): nums = np.array([1, 2, 3, 4, 5, 6]) @@ -275,7 +275,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual((nums + 3) * 2, received[0]) self.assertAllEqual(-(nums + 3) * 2, received[1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMap_MultiOutputMismatchedDtype(self): with self.test_session(): nums = np.array([1, 2, 3, 4, 5, 6]) @@ -287,7 +287,7 @@ class FunctionalOpsTest(test.TestCase): nums, dtype=[dtypes.int64, dtypes.int64]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMap_MultiInputSingleOutput(self): with self.test_session(): nums = np.array([1, 2, 3, 4, 5, 6]) @@ -298,7 +298,7 @@ class FunctionalOpsTest(test.TestCase): received = self.evaluate(r) self.assertAllEqual(nums * nums + (-nums), received) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMap_MultiInputSameStructureOutput(self): with self.test_session(): nums = np.array([1, 2, 3, 4, 5, 6]) @@ -313,7 +313,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(-nums, received[1]) self.assertAllEqual(nums, received[2]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScan_Simple(self): with self.test_session(): elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") @@ -328,7 +328,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r)) # pylint: enable=unnecessary-lambda - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScan_Reverse(self): with self.test_session(): elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") @@ -345,7 +345,7 @@ class FunctionalOpsTest(test.TestCase): self.evaluate(r)) # pylint: enable=unnecessary-lambda - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScan_SingleInputMultiOutput(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -357,7 +357,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0]) self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScan_MultiInputSingleOutput(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -367,7 +367,7 @@ class FunctionalOpsTest(test.TestCase): (elems + 1, -elems), initializer) self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], self.evaluate(r)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScan_MultiInputSameTypeOutput(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -377,7 +377,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(np.cumsum(elems), r_value[0]) self.assertAllEqual(np.cumsum(-elems), r_value[1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScan_MultiOutputMismatchedInitializer(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -408,7 +408,7 @@ class FunctionalOpsTest(test.TestCase): results = np.array([6, 16, 38, 84, 178, 368]) self.assertAllEqual(results, self.evaluate(r)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScanFoldl_Nested(self): with self.test_session(): elems = constant_op.constant([1.0, 2.0, 3.0, 4.0], name="data") @@ -467,7 +467,7 @@ class FunctionalOpsTest(test.TestCase): variables.global_variables_initializer().run() sess.run(grad) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFoldShape(self): with self.test_session(): x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) @@ -479,7 +479,7 @@ class FunctionalOpsTest(test.TestCase): y = functional_ops.foldl(fn, x, initializer=initializer) self.assertAllEqual(y.get_shape(), self.evaluate(y).shape) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMapShape(self): with self.test_session(): x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) @@ -491,7 +491,7 @@ class FunctionalOpsTest(test.TestCase): y = functional_ops.map_fn(lambda e: e, x) self.assertIs(None, y.get_shape().dims) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMapEmptyScalar(self): with self.test_session(): map_return = functional_ops.map_fn(lambda x: 1, constant_op.constant([])) @@ -507,7 +507,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual([0, 3, 2], map_return.get_shape().dims) self.assertAllEqual([0, 3, 2], self.evaluate(map_return).shape) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScanShape(self): with self.test_session(): x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) @@ -604,6 +604,25 @@ class FunctionalOpsTest(test.TestCase): mul = sess.run(remote_op) self.assertEqual(mul, [6]) + def testRemoteFunctionSameDeviceDirectSession(self): + + @function.Defun(dtypes.int32, dtypes.int32) + def _remote_fn(a, b): + return math_ops.multiply(a, b) + + with ops.device("/cpu:0"): + a = variables.Variable(2, dtype=dtypes.int32) + b = variables.Variable(3, dtype=dtypes.int32) + + with ops.device("/cpu:0"): + remote_op = functional_ops.remote_call( + args=[a, b], Tout=[dtypes.int32], f=_remote_fn, target="/cpu:0") + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + mul = sess.run(remote_op) + self.assertEqual(mul, [6]) + def testRemoteFunctionCPUGPU(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") @@ -652,6 +671,24 @@ class FunctionalOpsTest(test.TestCase): mul = sess.run(remote_op) self.assertEqual(mul, 9.0) + def testRemoteFunctionGPUCPUStrings(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + @function.Defun(dtypes.string) + def _remote_fn(inp): + return array_ops.identity(inp) + + a = array_ops.constant("a") + + with ops.device("/gpu:0"): + remote_op = functional_ops.remote_call( + args=[a], Tout=[dtypes.string], f=_remote_fn, target="/cpu:0") + + with self.test_session() as sess: + ret = sess.run(remote_op) + self.assertAllEqual(ret, [b"a"]) + def testRemoteFunctionCrossProcess(self): workers, _ = test_util.create_local_cluster(2, 1) diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py index 795aa67248f66e72f8f772845c4ca5b2b1b06d3d..927ca012ae6fc876364734c6f9bafd62ccc87467 100644 --- a/tensorflow/python/kernel_tests/init_ops_test.py +++ b/tensorflow/python/kernel_tests/init_ops_test.py @@ -364,14 +364,52 @@ class UniformUnitScalingInitializationTest(test.TestCase): class VarianceScalingInitializationTest(test.TestCase): + def testTruncatedNormalDistribution(self): + shape = [100, 100] + expect_mean = 0. + expect_var = 1. / shape[0] + init = init_ops.variance_scaling_initializer( + distribution='truncated_normal') + + with self.test_session(use_gpu=True), \ + test.mock.patch.object( + random_ops, 'truncated_normal', wraps=random_ops.truncated_normal) \ + as mock_truncated_normal: + x = init(shape).eval() + self.assertTrue(mock_truncated_normal.called) + + self.assertNear(np.mean(x), expect_mean, err=1e-2) + self.assertNear(np.var(x), expect_var, err=1e-2) + def testNormalDistribution(self): shape = [100, 100] expect_mean = 0. expect_var = 1. / shape[0] init = init_ops.variance_scaling_initializer(distribution='normal') - with self.test_session(use_gpu=True): + with self.test_session(use_gpu=True), \ + test.mock.patch.object( + random_ops, 'truncated_normal', wraps=random_ops.truncated_normal) \ + as mock_truncated_normal: + x = init(shape).eval() + self.assertTrue(mock_truncated_normal.called) + + self.assertNear(np.mean(x), expect_mean, err=1e-2) + self.assertNear(np.var(x), expect_var, err=1e-2) + + def testUntruncatedNormalDistribution(self): + shape = [100, 100] + expect_mean = 0. + expect_var = 1. / shape[0] + init = init_ops.variance_scaling_initializer( + distribution='untruncated_normal') + + with self.test_session(use_gpu=True), \ + test.mock.patch.object( + random_ops, 'random_normal', wraps=random_ops.random_normal) \ + as mock_random_normal: x = init(shape).eval() + self.assertTrue(mock_random_normal.called) self.assertNear(np.mean(x), expect_mean, err=1e-2) self.assertNear(np.var(x), expect_var, err=1e-2) diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index 0123adc2c3e5c32fd86ef11e7b1f552964232abd..69d3aa401751f56ea338a5ac4b24d65e68dbddeb 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -107,6 +107,10 @@ cuda_py_test( "//tensorflow/python:random_ops", ], shard_count = 5, + tags = [ + "noasan", + "optonly", + ], ) cuda_py_test( @@ -124,7 +128,10 @@ cuda_py_test( "//tensorflow/python:random_ops", ], shard_count = 5, - tags = ["optonly"], # Test is flaky without optimization. + tags = [ + "noasan", + "optonly", + ], ) cuda_py_test( @@ -141,6 +148,10 @@ cuda_py_test( "//tensorflow/python:platform_test", ], shard_count = 5, + tags = [ + "noasan", + "optonly", + ], ) cuda_py_test( @@ -178,6 +189,10 @@ cuda_py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], + tags = [ + "noasan", + "optonly", + ], ) cuda_py_test( @@ -214,4 +229,8 @@ cuda_py_test( "//tensorflow/python:platform_test", ], shard_count = 5, + tags = [ + "noasan", + "optonly", + ], ) diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py index 2b80f01b73441185281a3e2ef4db003b150c1e12..3ede2aceaa51c2795029ba13b763fed3e2ddc441 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py @@ -80,7 +80,7 @@ class SquareLinearOperatorBlockDiagTest( build_info((2, 1, 5, 5), blocks=[(2, 1, 2, 2), (1, 3, 3)]), ] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) expected_blocks = ( build_info.__dict__["blocks"] if "blocks" in build_info.__dict__ @@ -91,26 +91,19 @@ class SquareLinearOperatorBlockDiagTest( for block_shape in expected_blocks ] + lin_op_matrices = matrices + if use_placeholder: - matrices_ph = [ - array_ops.placeholder(dtype=dtype) for _ in expected_blocks - ] - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - matrices = self.evaluate(matrices) - operator = block_diag.LinearOperatorBlockDiag( - [linalg.LinearOperatorFullMatrix( - m_ph, is_square=True) for m_ph in matrices_ph], - is_square=True) - feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)} - else: - operator = block_diag.LinearOperatorBlockDiag( - [linalg.LinearOperatorFullMatrix( - m, is_square=True) for m in matrices]) - feed_dict = None - # Should be auto-set. - self.assertTrue(operator.is_square) + lin_op_matrices = [ + array_ops.placeholder_with_default( + matrix, shape=None) for matrix in matrices] + + operator = block_diag.LinearOperatorBlockDiag( + [linalg.LinearOperatorFullMatrix( + l, is_square=True) for l in lin_op_matrices]) + + # Should be auto-set. + self.assertTrue(operator.is_square) # Broadcast the shapes. expected_shape = list(build_info.shape) @@ -123,7 +116,7 @@ class SquareLinearOperatorBlockDiagTest( block_diag_dense.set_shape( expected_shape[:-2] + [expected_shape[-1], expected_shape[-1]]) - return operator, block_diag_dense, feed_dict + return operator, block_diag_dense def test_is_x_flags(self): # Matrix with two positive eigenvalues, 1, and 1. diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py index 5713d169696c78e996332b7a515a3ee2eedca839..7261d4bb3bc4aa24f51be21f9ac261549dca58d5 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py @@ -95,7 +95,7 @@ class LinearOperatorCirculantTestSelfAdjointOperator( # real, the matrix will not be real. return [dtypes.complex64] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = build_info.shape # For this test class, we are creating real spectrums. # We also want the spectrum to have eigenvalues bounded away from zero. @@ -107,22 +107,18 @@ class LinearOperatorCirculantTestSelfAdjointOperator( # zero, so the operator will still be self-adjoint. spectrum = math_ops.cast(spectrum, dtype) + lin_op_spectrum = spectrum + if use_placeholder: - spectrum_ph = array_ops.placeholder(dtypes.complex64) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # it is random and we want the same value used for both mat and feed_dict. - spectrum = spectrum.eval() - operator = linalg.LinearOperatorCirculant( - spectrum_ph, is_self_adjoint=True, input_output_dtype=dtype) - feed_dict = {spectrum_ph: spectrum} - else: - operator = linalg.LinearOperatorCirculant( - spectrum, is_self_adjoint=True, input_output_dtype=dtype) - feed_dict = None + lin_op_spectrum = array_ops.placeholder_with_default( + spectrum, shape=None) + + operator = linalg.LinearOperatorCirculant( + lin_op_spectrum, is_self_adjoint=True, input_output_dtype=dtype) mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype) - return operator, mat, feed_dict + return operator, mat def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self): with self.test_session(): @@ -149,7 +145,7 @@ class LinearOperatorCirculantTestHermitianSpectrum( def _dtypes_to_test(self): return [dtypes.float32, dtypes.complex64] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = build_info.shape # For this test class, we are creating Hermitian spectrums. # We also want the spectrum to have eigenvalues bounded away from zero. @@ -172,22 +168,18 @@ class LinearOperatorCirculantTestHermitianSpectrum( spectrum = math_ops.fft(h_c) + lin_op_spectrum = spectrum + if use_placeholder: - spectrum_ph = array_ops.placeholder(dtypes.complex64) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # it is random and we want the same value used for both mat and feed_dict. - spectrum = spectrum.eval() - operator = linalg.LinearOperatorCirculant( - spectrum_ph, input_output_dtype=dtype) - feed_dict = {spectrum_ph: spectrum} - else: - operator = linalg.LinearOperatorCirculant( - spectrum, input_output_dtype=dtype) - feed_dict = None + lin_op_spectrum = array_ops.placeholder_with_default( + spectrum, shape=None) + + operator = linalg.LinearOperatorCirculant( + lin_op_spectrum, input_output_dtype=dtype) mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype) - return operator, mat, feed_dict + return operator, mat def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self): with self.test_session(): @@ -213,7 +205,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum( def _dtypes_to_test(self): return [dtypes.complex64] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = build_info.shape # Will be well conditioned enough to get accurate solves. spectrum = linear_operator_test_util.random_sign_uniform( @@ -222,22 +214,18 @@ class LinearOperatorCirculantTestNonHermitianSpectrum( minval=1., maxval=2.) + lin_op_spectrum = spectrum + if use_placeholder: - spectrum_ph = array_ops.placeholder(dtypes.complex64) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # it is random and we want the same value used for both mat and feed_dict. - spectrum = spectrum.eval() - operator = linalg.LinearOperatorCirculant( - spectrum_ph, input_output_dtype=dtype) - feed_dict = {spectrum_ph: spectrum} - else: - operator = linalg.LinearOperatorCirculant( - spectrum, input_output_dtype=dtype) - feed_dict = None + lin_op_spectrum = array_ops.placeholder_with_default( + spectrum, shape=None) + + operator = linalg.LinearOperatorCirculant( + lin_op_spectrum, input_output_dtype=dtype) mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype) - return operator, mat, feed_dict + return operator, mat def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self): with self.test_session(): @@ -432,7 +420,7 @@ class LinearOperatorCirculant2DTestHermitianSpectrum( def _dtypes_to_test(self): return [dtypes.float32, dtypes.complex64] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = build_info.shape # For this test class, we are creating Hermitian spectrums. # We also want the spectrum to have eigenvalues bounded away from zero. @@ -455,22 +443,18 @@ class LinearOperatorCirculant2DTestHermitianSpectrum( spectrum = math_ops.fft2d(h_c) + lin_op_spectrum = spectrum + if use_placeholder: - spectrum_ph = array_ops.placeholder(dtypes.complex64) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # it is random and we want the same value used for both mat and feed_dict. - spectrum = spectrum.eval() - operator = linalg.LinearOperatorCirculant2D( - spectrum_ph, input_output_dtype=dtype) - feed_dict = {spectrum_ph: spectrum} - else: - operator = linalg.LinearOperatorCirculant2D( - spectrum, input_output_dtype=dtype) - feed_dict = None + lin_op_spectrum = array_ops.placeholder_with_default( + spectrum, shape=None) + + operator = linalg.LinearOperatorCirculant2D( + lin_op_spectrum, input_output_dtype=dtype) mat = self._spectrum_to_circulant_2d(spectrum, shape, dtype=dtype) - return operator, mat, feed_dict + return operator, mat class LinearOperatorCirculant2DTestNonHermitianSpectrum( @@ -486,7 +470,7 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum( def _dtypes_to_test(self): return [dtypes.complex64] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = build_info.shape # Will be well conditioned enough to get accurate solves. spectrum = linear_operator_test_util.random_sign_uniform( @@ -495,22 +479,18 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum( minval=1., maxval=2.) + lin_op_spectrum = spectrum + if use_placeholder: - spectrum_ph = array_ops.placeholder(dtypes.complex64) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # it is random and we want the same value used for both mat and feed_dict. - spectrum = spectrum.eval() - operator = linalg.LinearOperatorCirculant2D( - spectrum_ph, input_output_dtype=dtype) - feed_dict = {spectrum_ph: spectrum} - else: - operator = linalg.LinearOperatorCirculant2D( - spectrum, input_output_dtype=dtype) - feed_dict = None + lin_op_spectrum = array_ops.placeholder_with_default( + spectrum, shape=None) + + operator = linalg.LinearOperatorCirculant2D( + lin_op_spectrum, input_output_dtype=dtype) mat = self._spectrum_to_circulant_2d(spectrum, shape, dtype=dtype) - return operator, mat, feed_dict + return operator, mat def test_real_hermitian_spectrum_gives_real_symmetric_operator(self): with self.test_session() as sess: diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py index f96b9ccdaacae7d8e0552ed3d74ce53808fed963..612a50bcec771f8511d20d19b312a797d531f109 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py @@ -44,7 +44,7 @@ class SquareLinearOperatorCompositionTest( self._rtol[dtypes.float32] = 1e-4 self._rtol[dtypes.complex64] = 1e-4 - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): sess = ops.get_default_session() shape = list(build_info.shape) @@ -56,33 +56,23 @@ class SquareLinearOperatorCompositionTest( for _ in range(num_operators) ] + lin_op_matrices = matrices + if use_placeholder: - matrices_ph = [ - array_ops.placeholder(dtype=dtype) for _ in range(num_operators) - ] - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - matrices = sess.run(matrices) - operator = linalg.LinearOperatorComposition( - [linalg.LinearOperatorFullMatrix(m_ph) for m_ph in matrices_ph], - is_square=True) - feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)} - else: - operator = linalg.LinearOperatorComposition( - [linalg.LinearOperatorFullMatrix(m) for m in matrices]) - feed_dict = None - # Should be auto-set. - self.assertTrue(operator.is_square) - - # Convert back to Tensor. Needed if use_placeholder, since then we have - # already evaluated each matrix to a numpy array. + lin_op_matrices = [ + array_ops.placeholder_with_default( + matrix, shape=None) for matrix in matrices] + + operator = linalg.LinearOperatorComposition( + [linalg.LinearOperatorFullMatrix(l) for l in lin_op_matrices], + is_square=True) + matmul_order_list = list(reversed(matrices)) - mat = ops.convert_to_tensor(matmul_order_list[0]) + mat = matmul_order_list[0] for other_mat in matmul_order_list[1:]: mat = math_ops.matmul(other_mat, mat) - return operator, mat, feed_dict + return operator, mat def test_is_x_flags(self): # Matrix with two positive eigenvalues, 1, and 1. @@ -148,7 +138,7 @@ class NonSquareLinearOperatorCompositionTest( self._rtol[dtypes.float32] = 1e-4 self._rtol[dtypes.complex64] = 1e-4 - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): sess = ops.get_default_session() shape = list(build_info.shape) @@ -170,30 +160,22 @@ class NonSquareLinearOperatorCompositionTest( shape_2, dtype=dtype) ] + lin_op_matrices = matrices + if use_placeholder: - matrices_ph = [ - array_ops.placeholder(dtype=dtype) for _ in range(num_operators) - ] - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - matrices = sess.run(matrices) - operator = linalg.LinearOperatorComposition( - [linalg.LinearOperatorFullMatrix(m_ph) for m_ph in matrices_ph]) - feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)} - else: - operator = linalg.LinearOperatorComposition( - [linalg.LinearOperatorFullMatrix(m) for m in matrices]) - feed_dict = None - - # Convert back to Tensor. Needed if use_placeholder, since then we have - # already evaluated each matrix to a numpy array. + lin_op_matrices = [ + array_ops.placeholder_with_default( + matrix, shape=None) for matrix in matrices] + + operator = linalg.LinearOperatorComposition( + [linalg.LinearOperatorFullMatrix(l) for l in lin_op_matrices]) + matmul_order_list = list(reversed(matrices)) - mat = ops.convert_to_tensor(matmul_order_list[0]) + mat = matmul_order_list[0] for other_mat in matmul_order_list[1:]: mat = math_ops.matmul(other_mat, mat) - return operator, mat, feed_dict + return operator, mat def test_static_shapes(self): operators = [ diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py index 0a0e31c716ecfa10ed93cff92fa908a240f8495e..83cc8c483f9aec6dd0ddf3f961a8180af7515e40 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py @@ -34,25 +34,21 @@ class LinearOperatorDiagTest( linear_operator_test_util.SquareLinearOperatorDerivedClassTest): """Most tests done in the base class LinearOperatorDerivedClassTest.""" - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) diag = linear_operator_test_util.random_sign_uniform( shape[:-1], minval=1., maxval=2., dtype=dtype) + + lin_op_diag = diag + if use_placeholder: - diag_ph = array_ops.placeholder(dtype=dtype) - # Evaluate the diag here because (i) you cannot feed a tensor, and (ii) - # diag is random and we want the same value used for both mat and - # feed_dict. - diag = diag.eval() - operator = linalg.LinearOperatorDiag(diag_ph) - feed_dict = {diag_ph: diag} - else: - operator = linalg.LinearOperatorDiag(diag) - feed_dict = None + lin_op_diag = array_ops.placeholder_with_default(diag, shape=None) + + operator = linalg.LinearOperatorDiag(lin_op_diag) - mat = array_ops.matrix_diag(diag) + matrix = array_ops.matrix_diag(diag) - return operator, mat, feed_dict + return operator, matrix def test_assert_positive_definite_raises_for_zero_eigenvalue(self): # Matrix with one positive eigenvalue and one zero eigenvalue. diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py index b3da623b5e8d8c99c6777e75e2d49f24dab1c96b..1a40a29ec6a040ca3d98e0b27492b1379d30cb4b 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py @@ -20,7 +20,6 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -36,30 +35,20 @@ class SquareLinearOperatorFullMatrixTest( linear_operator_test_util.SquareLinearOperatorDerivedClassTest): """Most tests done in the base class LinearOperatorDerivedClassTest.""" - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) matrix = linear_operator_test_util.random_positive_definite_matrix( shape, dtype) + lin_op_matrix = matrix + if use_placeholder: - matrix_ph = array_ops.placeholder(dtype=dtype) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - matrix = matrix.eval() - operator = linalg.LinearOperatorFullMatrix(matrix_ph, is_square=True) - feed_dict = {matrix_ph: matrix} - else: - # is_square should be auto-detected here. - operator = linalg.LinearOperatorFullMatrix(matrix) - feed_dict = None + lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None) - # Convert back to Tensor. Needed if use_placeholder, since then we have - # already evaluated matrix to a numpy array. - mat = ops.convert_to_tensor(matrix) + operator = linalg.LinearOperatorFullMatrix(lin_op_matrix, is_square=True) - return operator, mat, feed_dict + return operator, matrix def test_is_x_flags(self): # Matrix with two positive eigenvalues. @@ -136,32 +125,20 @@ class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest( def _dtypes_to_test(self): return [dtypes.float32, dtypes.float64] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) matrix = linear_operator_test_util.random_positive_definite_matrix( shape, dtype, force_well_conditioned=True) + lin_op_matrix = matrix + if use_placeholder: - matrix_ph = array_ops.placeholder(dtype=dtype) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - matrix = matrix.eval() - # is_square is auto-set because of self_adjoint/pd. - operator = linalg.LinearOperatorFullMatrix( - matrix_ph, is_self_adjoint=True, is_positive_definite=True) - feed_dict = {matrix_ph: matrix} - else: - operator = linalg.LinearOperatorFullMatrix( - matrix, is_self_adjoint=True, is_positive_definite=True) - feed_dict = None - - # Convert back to Tensor. Needed if use_placeholder, since then we have - # already evaluated matrix to a numpy array. - mat = ops.convert_to_tensor(matrix) - - return operator, mat, feed_dict + lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None) + + operator = linalg.LinearOperatorFullMatrix(lin_op_matrix, is_square=True) + + return operator, matrix def test_is_x_flags(self): # Matrix with two positive eigenvalues. @@ -210,26 +187,18 @@ class NonSquareLinearOperatorFullMatrixTest( linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest): """Most tests done in the base class LinearOperatorDerivedClassTest.""" - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) matrix = linear_operator_test_util.random_normal(shape, dtype=dtype) + + lin_op_matrix = matrix + if use_placeholder: - matrix_ph = array_ops.placeholder(dtype=dtype) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - matrix = matrix.eval() - operator = linalg.LinearOperatorFullMatrix(matrix_ph) - feed_dict = {matrix_ph: matrix} - else: - operator = linalg.LinearOperatorFullMatrix(matrix) - feed_dict = None + lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None) - # Convert back to Tensor. Needed if use_placeholder, since then we have - # already evaluated matrix to a numpy array. - mat = ops.convert_to_tensor(matrix) + operator = linalg.LinearOperatorFullMatrix(lin_op_matrix, is_square=True) - return operator, mat, feed_dict + return operator, matrix def test_is_x_flags(self): matrix = [[3., 2., 1.], [1., 1., 1.]] diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py index 59f63f949e96991193412d3574603e58a75cb6e5..35dcf4417c313f5cbc00c8b66b4c5d1f2e157212 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py @@ -43,7 +43,7 @@ class LinearOperatorIdentityTest( # 16bit. return [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) assert shape[-1] == shape[-2] @@ -54,13 +54,7 @@ class LinearOperatorIdentityTest( num_rows, batch_shape=batch_shape, dtype=dtype) mat = linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=dtype) - # Nothing to feed since LinearOperatorIdentity takes no Tensor args. - if use_placeholder: - feed_dict = {} - else: - feed_dict = None - - return operator, mat, feed_dict + return operator, mat def test_assert_positive_definite(self): with self.test_session(): @@ -261,7 +255,7 @@ class LinearOperatorScaledIdentityTest( # 16bit. return [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) assert shape[-1] == shape[-2] @@ -274,24 +268,23 @@ class LinearOperatorScaledIdentityTest( multiplier = linear_operator_test_util.random_sign_uniform( shape=batch_shape, minval=1., maxval=2., dtype=dtype) - operator = linalg_lib.LinearOperatorScaledIdentity(num_rows, multiplier) # Nothing to feed since LinearOperatorScaledIdentity takes no Tensor args. + lin_op_multiplier = multiplier + if use_placeholder: - multiplier_ph = array_ops.placeholder(dtype=dtype) - multiplier = multiplier.eval() - operator = linalg_lib.LinearOperatorScaledIdentity( - num_rows, multiplier_ph) - feed_dict = {multiplier_ph: multiplier} - else: - feed_dict = None + lin_op_multiplier = array_ops.placeholder_with_default( + multiplier, shape=None) + + operator = linalg_lib.LinearOperatorScaledIdentity( + num_rows, lin_op_multiplier) multiplier_matrix = array_ops.expand_dims( array_ops.expand_dims(multiplier, -1), -1) - mat = multiplier_matrix * linalg_ops.eye( + matrix = multiplier_matrix * linalg_ops.eye( num_rows, batch_shape=batch_shape, dtype=dtype) - return operator, mat, feed_dict + return operator, matrix def test_assert_positive_definite_does_not_raise_when_positive(self): with self.test_session(): diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py index 784c730bbc8179dd1302294b2d558e8a0c532c0c..e26b946151dd8ddb923e34352feb6b483f9752fc 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py @@ -101,7 +101,7 @@ class SquareLinearOperatorKroneckerTest( def _tests_to_skip(self): return ["det", "solve", "solve_with_broadcast"] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) expected_factors = build_info.__dict__["factors"] matrices = [ @@ -110,26 +110,15 @@ class SquareLinearOperatorKroneckerTest( for block_shape in expected_factors ] + lin_op_matrices = matrices + if use_placeholder: - matrices_ph = [ - array_ops.placeholder(dtype=dtype) for _ in expected_factors - ] - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - matrices = self.evaluate(matrices) - operator = kronecker.LinearOperatorKronecker( - [linalg.LinearOperatorFullMatrix( - m_ph, is_square=True) for m_ph in matrices_ph], - is_square=True) - feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)} - else: - operator = kronecker.LinearOperatorKronecker( - [linalg.LinearOperatorFullMatrix( - m, is_square=True) for m in matrices]) - feed_dict = None - # Should be auto-set. - self.assertTrue(operator.is_square) + lin_op_matrices = [ + array_ops.placeholder_with_default(m, shape=None) for m in matrices] + + operator = kronecker.LinearOperatorKronecker( + [linalg.LinearOperatorFullMatrix( + l, is_square=True) for l in lin_op_matrices]) matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices) @@ -138,7 +127,7 @@ class SquareLinearOperatorKroneckerTest( if not use_placeholder: kronecker_dense.set_shape(shape) - return operator, kronecker_dense, feed_dict + return operator, kronecker_dense def test_is_x_flags(self): # Matrix with two positive eigenvalues, 1, and 1. diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py index 8095f6419ef0d9543339cf1f4ee9cd4783f852b9..34b35a4ffb878c63f851f2b31491e7bfa4057417 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py @@ -68,7 +68,7 @@ class BaseLinearOperatorLowRankUpdatetest(object): build_info((3, 4, 4)), build_info((2, 1, 4, 4))] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): # Recall A = L + UDV^H shape = list(build_info.shape) diag_shape = shape[:-1] @@ -80,17 +80,17 @@ class BaseLinearOperatorLowRankUpdatetest(object): # operator, with condition number as high as 1e4. base_diag = linear_operator_test_util.random_uniform( diag_shape, minval=1e-4, maxval=1., dtype=dtype) - base_diag_ph = array_ops.placeholder(dtype=dtype) + lin_op_base_diag = base_diag # U u = linear_operator_test_util.random_normal_correlated_columns( u_perturbation_shape, dtype=dtype) - u_ph = array_ops.placeholder(dtype=dtype) + lin_op_u = u # V v = linear_operator_test_util.random_normal_correlated_columns( u_perturbation_shape, dtype=dtype) - v_ph = array_ops.placeholder(dtype=dtype) + lin_op_v = v # D if self._is_diag_update_positive: @@ -99,42 +99,25 @@ class BaseLinearOperatorLowRankUpdatetest(object): else: diag_update = linear_operator_test_util.random_normal( diag_update_shape, stddev=1e-4, dtype=dtype) - diag_update_ph = array_ops.placeholder(dtype=dtype) + lin_op_diag_update = diag_update if use_placeholder: - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - base_diag = base_diag.eval() - u = u.eval() - v = v.eval() - diag_update = diag_update.eval() - - # In all cases, set base_operator to be positive definite. - base_operator = linalg.LinearOperatorDiag( - base_diag_ph, is_positive_definite=True) - - operator = linalg.LinearOperatorLowRankUpdate( - base_operator, - u=u_ph, - v=v_ph if self._use_v else None, - diag_update=diag_update_ph if self._use_diag_update else None, - is_diag_update_positive=self._is_diag_update_positive) - feed_dict = { - base_diag_ph: base_diag, - u_ph: u, - v_ph: v, - diag_update_ph: diag_update} - else: - base_operator = linalg.LinearOperatorDiag( - base_diag, is_positive_definite=True) - operator = linalg.LinearOperatorLowRankUpdate( - base_operator, - u, - v=v if self._use_v else None, - diag_update=diag_update if self._use_diag_update else None, - is_diag_update_positive=self._is_diag_update_positive) - feed_dict = None + lin_op_base_diag = array_ops.placeholder_with_default( + base_diag, shape=None) + lin_op_u = array_ops.placeholder_with_default(u, shape=None) + lin_op_v = array_ops.placeholder_with_default(v, shape=None) + lin_op_diag_update = array_ops.placeholder_with_default( + diag_update, shape=None) + + base_operator = linalg.LinearOperatorDiag( + lin_op_base_diag, is_positive_definite=True) + + operator = linalg.LinearOperatorLowRankUpdate( + base_operator, + lin_op_u, + v=lin_op_v if self._use_v else None, + diag_update=lin_op_diag_update if self._use_diag_update else None, + is_diag_update_positive=self._is_diag_update_positive) # The matrix representing L base_diag_mat = array_ops.matrix_diag(base_diag) @@ -146,28 +129,28 @@ class BaseLinearOperatorLowRankUpdatetest(object): if self._use_v and self._use_diag_update: # In this case, we have L + UDV^H and it isn't symmetric. expect_use_cholesky = False - mat = base_diag_mat + math_ops.matmul( + matrix = base_diag_mat + math_ops.matmul( u, math_ops.matmul(diag_update_mat, v, adjoint_b=True)) elif self._use_v: # In this case, we have L + UDV^H and it isn't symmetric. expect_use_cholesky = False - mat = base_diag_mat + math_ops.matmul(u, v, adjoint_b=True) + matrix = base_diag_mat + math_ops.matmul(u, v, adjoint_b=True) elif self._use_diag_update: # In this case, we have L + UDU^H, which is PD if D > 0, since L > 0. expect_use_cholesky = self._is_diag_update_positive - mat = base_diag_mat + math_ops.matmul( + matrix = base_diag_mat + math_ops.matmul( u, math_ops.matmul(diag_update_mat, u, adjoint_b=True)) else: # In this case, we have L + UU^H, which is PD since L > 0. expect_use_cholesky = True - mat = base_diag_mat + math_ops.matmul(u, u, adjoint_b=True) + matrix = base_diag_mat + math_ops.matmul(u, u, adjoint_b=True) if expect_use_cholesky: self.assertTrue(operator._use_cholesky) else: self.assertFalse(operator._use_cholesky) - return operator, mat, feed_dict + return operator, matrix class LinearOperatorLowRankUpdatetestWithDiagUseCholesky( diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py index a57d2f085e089fb913f09fdd9b07cf13aa7f3c35..167c6cacd1a5bbbaa70a7fdd236ddd70ea8cd4e8 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py @@ -38,28 +38,23 @@ class LinearOperatorLowerTriangularTest( # matrix_triangular_solve. return [dtypes.float32, dtypes.float64] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) # Upper triangle will be nonzero, but ignored. # Use a diagonal that ensures this matrix is well conditioned. tril = linear_operator_test_util.random_tril_matrix( shape, dtype=dtype, force_well_conditioned=True, remove_upper=False) + lin_op_tril = tril + if use_placeholder: - tril_ph = array_ops.placeholder(dtype=dtype) - # Evaluate the tril here because (i) you cannot feed a tensor, and (ii) - # tril is random and we want the same value used for both mat and - # feed_dict. - tril = tril.eval() - operator = linalg.LinearOperatorLowerTriangular(tril_ph) - feed_dict = {tril_ph: tril} - else: - operator = linalg.LinearOperatorLowerTriangular(tril) - feed_dict = None + lin_op_tril = array_ops.placeholder_with_default(lin_op_tril, shape=None) + + operator = linalg.LinearOperatorLowerTriangular(lin_op_tril) - mat = array_ops.matrix_band_part(tril, -1, 0) + matrix = array_ops.matrix_band_part(tril, -1, 0) - return operator, mat, feed_dict + return operator, matrix def test_assert_non_singular(self): # Singlular matrix with one positive eigenvalue and one zero eigenvalue. diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py index 49855200c2427a88a4bd582c2ef786c38a6fa76a..bf82e08551e6a276b95bf77f7932c31d7a844a78 100644 --- a/tensorflow/python/kernel_tests/list_ops_test.py +++ b/tensorflow/python/kernel_tests/list_ops_test.py @@ -46,7 +46,7 @@ def scalar_shape(): @test_util.with_c_shapes class ListOpsTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPushPop(self): l = list_ops.empty_tensor_list(element_dtype=dtypes.float32, element_shape=scalar_shape()) @@ -54,14 +54,14 @@ class ListOpsTest(test_util.TensorFlowTestCase): l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) self.assertAllEqual(self.evaluate(e), 1.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPushPopGPU(self): if not context.num_gpus(): return with context.device("gpu:0"): self.testPushPop() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testStack(self): l = list_ops.empty_tensor_list(element_dtype=dtypes.float32, element_shape=scalar_shape()) @@ -70,14 +70,14 @@ class ListOpsTest(test_util.TensorFlowTestCase): t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) self.assertAllEqual(self.evaluate(t), [1.0, 2.0]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testStackGPU(self): if not context.num_gpus(): return with context.device("gpu:0"): self.testStack() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorListFromTensor(self): t = constant_op.constant([1.0, 2.0]) l = list_ops.tensor_list_from_tensor(t, element_shape=scalar_shape()) @@ -87,14 +87,14 @@ class ListOpsTest(test_util.TensorFlowTestCase): self.assertAllEqual(self.evaluate(e), 1.0) self.assertAllEqual(self.evaluate(list_ops.tensor_list_length(l)), 0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFromTensorGPU(self): if not context.num_gpus(): return with context.device("gpu:0"): self.testTensorListFromTensor() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetSetItem(self): t = constant_op.constant([1.0, 2.0]) l = list_ops.tensor_list_from_tensor(t, element_shape=scalar_shape()) @@ -104,14 +104,14 @@ class ListOpsTest(test_util.TensorFlowTestCase): t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) self.assertAllEqual(self.evaluate(t), [3.0, 2.0]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetSetGPU(self): if not context.num_gpus(): return with context.device("gpu:0"): self.testGetSetItem() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUnknownShape(self): l = list_ops.empty_tensor_list( element_dtype=dtypes.float32, element_shape=-1) @@ -122,7 +122,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) self.assertAllEqual(self.evaluate(e), 1.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCPUGPUCopy(self): if not context.num_gpus(): return @@ -140,7 +140,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): list_ops.tensor_list_pop_back( l_cpu, element_dtype=dtypes.float32)[1]), 2.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGraphStack(self): with context.graph_mode(), self.test_session(): tl = list_ops.empty_tensor_list( @@ -152,7 +152,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32)), [[1]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGraphStackInLoop(self): with context.graph_mode(), self.test_session(): t1 = list_ops.empty_tensor_list( @@ -170,7 +170,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.int32) self.assertAllEqual(self.evaluate(s1), [0, 1, 2, 3]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGraphStackSwitchDtype(self): with context.graph_mode(), self.test_session(): list_ = list_ops.empty_tensor_list( @@ -192,7 +192,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): np_s1 = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) self.assertAllEqual(self.evaluate(s1), np_s1) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGraphStackInLoopSwitchDtype(self): with context.graph_mode(), self.test_session(): t1 = list_ops.empty_tensor_list( @@ -216,7 +216,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): np_s1 = np.vstack([np.arange(1, 4) * i for i in range(4)]) self.assertAllEqual(self.evaluate(s1), np_s1) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSerialize(self): # pylint: disable=g-import-not-at-top try: @@ -248,7 +248,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): worker_e = array_ops.identity(e) self.assertAllEqual(self.evaluate(worker_e), [2.0]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPushPopGradients(self): with backprop.GradientTape() as tape: l = list_ops.empty_tensor_list(element_dtype=dtypes.float32, @@ -260,7 +260,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): e = 2 * e self.assertAllEqual(self.evaluate(tape.gradient(e, [c])[0]), 2.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testStackFromTensorGradients(self): with backprop.GradientTape() as tape: c = constant_op.constant([1.0, 2.0]) @@ -272,7 +272,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): grad = tape.gradient(result, [c])[0] self.assertAllEqual(self.evaluate(grad), [2.0, 2.0]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetSetGradients(self): with backprop.GradientTape() as tape: c = constant_op.constant([1.0, 2.0]) @@ -288,14 +288,14 @@ class ListOpsTest(test_util.TensorFlowTestCase): self.assertAllEqual(self.evaluate(grad_c), [0.0, 4.0]) self.assertAllEqual(self.evaluate(grad_c2), 6.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSetOutOfBounds(self): c = constant_op.constant([1.0, 2.0]) l = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape()) with self.assertRaises(errors.InvalidArgumentError): self.evaluate(list_ops.tensor_list_set_item(l, 20, 3.0)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testResourceVariableScatterGather(self): c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32) l = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape()) @@ -319,7 +319,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): [[1.0, 2.0]] * 4) self.assertAllEqual(self.evaluate(updated_v_stacked), expected) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConcat(self): c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32) l0 = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape()) @@ -379,7 +379,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): list_ops.tensor_list_concat_lists(l_batch_0, l_batch_of_int_tls, element_dtype=dtypes.float32)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPushBackBatch(self): c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32) l0 = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape()) diff --git a/tensorflow/python/kernel_tests/logging_ops_test.py b/tensorflow/python/kernel_tests/logging_ops_test.py index 28c85fa13ad100c38382d2b787ff965f9e3ca44e..e635a71c78484278b54bfc4de70e232834c37a0a 100644 --- a/tensorflow/python/kernel_tests/logging_ops_test.py +++ b/tensorflow/python/kernel_tests/logging_ops_test.py @@ -59,7 +59,7 @@ class LoggingOpsTest(test.TestCase): class PrintGradientTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPrintShape(self): inp = constant_op.constant(2.0, shape=[100, 32]) inp_printed = logging_ops.Print(inp, [inp]) diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py index 1123c20a165ba93bd380fa471a8be91f7005d7bb..87fc715783b972a20465827d697cf06637588154 100644 --- a/tensorflow/python/kernel_tests/losses_test.py +++ b/tensorflow/python/kernel_tests/losses_test.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -118,6 +119,14 @@ class AbsoluteDifferenceLossTest(test.TestCase): with self.test_session(): self.assertAlmostEqual(0.0, loss.eval(), 3) + @test_util.assert_no_new_pyobjects_executing_eagerly + def testEagerNoMemoryLeaked(self): + # This is a somewhat convoluted way of testing that nothing gets added to + # a global collection. + predictions = constant_op.constant([4, 8, 12, 8, 1, 3], shape=(2, 3)) + labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) + losses.absolute_difference(labels, predictions) + class SoftmaxCrossEntropyLossTest(test.TestCase): @@ -246,6 +255,13 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value') self.assertAlmostEqual(loss.eval(), 0.0, 3) + @test_util.assert_no_new_pyobjects_executing_eagerly + def testEagerNoMemoryLeaked(self): + logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = constant_op.constant([[0], [1], [2]], dtype=dtypes.int32) + losses.sparse_softmax_cross_entropy(labels, logits) + def testAllCorrectInt64Labels(self): with self.test_session(): logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py index 4239151070b35ac11511360bd85c5c2424d80f3d..50154a45a8b58f270509e404737c8650cbd2c5ff 100644 --- a/tensorflow/python/kernel_tests/py_func_test.py +++ b/tensorflow/python/kernel_tests/py_func_test.py @@ -460,7 +460,7 @@ class PyFuncTest(test.TestCase): self.assertEqual(initial_size, script_ops._py_funcs.size()) # ----- Tests for eager_py_func ----- - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerSingleOutputInt32(self): a = array_ops.ones((3, 3), dtype=dtypes.int32) x = array_ops.ones((3, 1), dtype=dtypes.int32) @@ -468,7 +468,7 @@ class PyFuncTest(test.TestCase): ret = self.evaluate(output) self.assertAllEqual(ret, [[3], [3], [3]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerSingleOutputFloat32(self): with test_util.device(use_gpu=True): a = array_ops.ones((3, 3), dtype=dtypes.float32) @@ -477,7 +477,7 @@ class PyFuncTest(test.TestCase): ret = self.evaluate(output) self.assertAllClose(ret, [[3.0], [3.0], [3.0]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerArrayOutput(self): with test_util.device(use_gpu=True): a = array_ops.ones((3, 3), dtype=dtypes.float32) @@ -487,7 +487,7 @@ class PyFuncTest(test.TestCase): ret = self.evaluate(output) self.assertAllEqual(ret, [[[3.0], [3.0], [3.0]]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerReturnNone(self): with test_util.device(use_gpu=True): def no_return_value(): @@ -500,7 +500,7 @@ class PyFuncTest(test.TestCase): else: self.assertIsNone(ret) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerPyFuncInDefun(self): with test_util.device(use_gpu=True): def wrapper(): @@ -512,7 +512,7 @@ class PyFuncTest(test.TestCase): ret = self.evaluate(wrapped()) self.assertAllEqual(ret, [[3.0], [3.0], [3.0]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerExceptionHandling(self): with test_util.device(use_gpu=True): self._testExceptionHandling( @@ -531,7 +531,7 @@ class PyFuncTest(test.TestCase): self._testExceptionHandling(WeirdError, errors.UnknownError, eager=True) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerReturningVariableRaisesError(self): def return_variable(): return resource_variable_ops.ResourceVariable(0.0) @@ -542,7 +542,7 @@ class PyFuncTest(test.TestCase): return_variable, inp=[], Tout=dtypes.float32) self.evaluate(output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerGradientTape(self): def f(x): @@ -565,7 +565,7 @@ class PyFuncTest(test.TestCase): dy_dx = gradients_impl.gradients(y, x)[0] self.assertEqual(self.evaluate(dy_dx), 6.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerGradientTapeMultipleArgs(self): def f(x, y): @@ -616,6 +616,25 @@ class PyFuncTest(test.TestCase): self.assertEqual(y, 1.0) self.assertEqual(dy_dx, 2.0) + def testEagerRespectsDevicePlacmentOfOp(self): + + def f(x): + return math_ops.square(x) + + def g(x): + return math_ops.add(x, x) + + with ops.device("/CPU:0"): + # Explicitly ask for the py_funcs to execute on CPU, even if + # a GPU is available. + x = array_ops.placeholder(dtypes.float32) + y = script_ops.eager_py_func(func=f, inp=[x], Tout=dtypes.float32) + z = script_ops.eager_py_func(func=g, inp=[y], Tout=dtypes.float32) + + with self.test_session(use_gpu=True) as sess: + output = sess.run(z, feed_dict={x: 3.0}) + self.assertEqual(output, 18.0) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/random/BUILD b/tensorflow/python/kernel_tests/random/BUILD index acd7566eec8e3fffd74db33234b03a0c87427a3e..3b3a28fc9a24104cc9032ab23dfc51e690d3ec94 100644 --- a/tensorflow/python/kernel_tests/random/BUILD +++ b/tensorflow/python/kernel_tests/random/BUILD @@ -107,6 +107,23 @@ cuda_py_test( tags = ["nozapfhahn"], ) +cuda_py_test( + name = "random_grad_test", + size = "small", + srcs = ["random_grad_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:random_grad", + "//tensorflow/python:random_ops", + ], +) + cuda_py_test( name = "random_poisson_test", size = "medium", diff --git a/tensorflow/python/kernel_tests/random/multinomial_op_test.py b/tensorflow/python/kernel_tests/random/multinomial_op_test.py index 051c7d86bf2342f15b587fc350bfbede7fae2285..bd64d61af8e793e71a319b6ac1af95bd7dd16a3d 100644 --- a/tensorflow/python/kernel_tests/random/multinomial_op_test.py +++ b/tensorflow/python/kernel_tests/random/multinomial_op_test.py @@ -54,7 +54,7 @@ native_sampler = random_ops.multinomial class MultinomialTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSmallEntropy(self): random_seed.set_random_seed(1618) for output_dtype in [np.int32, np.int64]: diff --git a/tensorflow/python/kernel_tests/random/random_grad_test.py b/tensorflow/python/kernel_tests/random/random_grad_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c1d455b785bbf562fb41f30cab7e0bb723a7b894 --- /dev/null +++ b/tensorflow/python/kernel_tests/random/random_grad_test.py @@ -0,0 +1,240 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.random_grad.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_grad +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging + + +class AddLeadingUnitDimensionsTest(test.TestCase): + + def testBasic(self): + ret = random_grad.add_leading_unit_dimensions(array_ops.ones([3, 2, 1]), 3) + self.assertAllEqual(ret.shape, [1, 1, 1, 3, 2, 1]) + + def testZeroExtraDimensions(self): + ret = random_grad.add_leading_unit_dimensions(array_ops.ones([3, 2, 1]), 0) + self.assertAllEqual(ret.shape, [3, 2, 1]) + + def testScalarInput(self): + ret = random_grad.add_leading_unit_dimensions(1.0, 2) + self.assertAllEqual(ret.shape, [1, 1]) + + def testUnknownShape(self): + x = array_ops.placeholder(dtypes.float32) + num_dimensions = array_ops.placeholder(dtypes.int32) + ret = random_grad.add_leading_unit_dimensions(x, num_dimensions) + with self.test_session() as sess: + ret_val = sess.run(ret, {x: np.ones([2, 2]), num_dimensions: 2}) + self.assertAllEqual(ret_val.shape, [1, 1, 2, 2]) + + +class RandomGammaGradTest(test.TestCase): + """Tests for derivative of a sample ~ Gamma(alpha, beta) wrt alpha and beta. + + The sample is an "implicit" function of alpha, beta and the independent random + noise u. The derivatives we are looking for are + d sample(alpha, beta, u) / dalpha (and dbeta). + + The derivative w.r.t. beta is computed by the standard automatic + differentiation, so we trust that it is computed correctly. + + The derivative w.r.t. alpha is computed by Eigen function, so we test it in + several ways. Unfortunately, the standard derivative checking by perturbing + the parameter is impossible here, because we cannot fix the value of u + in the random sampler. Instead, we compare the derivative for the given pair + of (sample, alpha) to the values computed in various ways, and also check + some statistical properties of the derivative. + """ + + def testGradientsShape(self): + shape = [2, 3] + alpha = array_ops.ones([2, 2]) + beta = array_ops.ones([1, 2]) + sample = random_ops.random_gamma(shape, alpha, beta) + grads_alpha, grads_beta = gradients_impl.gradients(sample, [alpha, beta]) + self.assertAllEqual(grads_alpha.shape, alpha.shape) + self.assertAllEqual(grads_beta.shape, beta.shape) + + def testGradientsShapeWithOneSamplePerParameter(self): + shape = [] + alpha = array_ops.ones([2, 2]) + beta = array_ops.ones([1, 2]) + sample = random_ops.random_gamma(shape, alpha, beta) + grads_alpha, grads_beta = gradients_impl.gradients(sample, [alpha, beta]) + self.assertAllEqual(grads_alpha.shape, alpha.shape) + self.assertAllEqual(grads_beta.shape, beta.shape) + + def testGradientsUnknownShape(self): + shape = array_ops.placeholder(dtypes.int32) + alpha = array_ops.placeholder(dtypes.float32) + beta = array_ops.placeholder(dtypes.float32) + sample = random_ops.random_gamma(shape, alpha, beta) + grads_alpha, grads_beta = gradients_impl.gradients(sample, [alpha, beta]) + + alpha_val = np.ones([1, 2]) + beta_val = np.ones([2, 1]) + with self.test_session() as sess: + grads_alpha_val, grads_beta_val = sess.run( + [grads_alpha, grads_beta], + {alpha: alpha_val, beta: beta_val, shape: [2, 1]}) + self.assertAllEqual(grads_alpha_val.shape, alpha_val.shape) + self.assertAllEqual(grads_beta_val.shape, beta_val.shape) + + def _testCompareToExplicitDerivative(self, dtype): + """Compare to the explicit reparameterization derivative. + + Verifies that the computed derivative satisfies + dsample / dalpha = d igammainv(alpha, u) / dalpha, + where u = igamma(alpha, sample). + + Args: + dtype: TensorFlow dtype to perform the computations in. + """ + delta = 1e-3 + np_dtype = dtype.as_numpy_dtype + try: + from scipy import misc # pylint: disable=g-import-not-at-top + from scipy import special # pylint: disable=g-import-not-at-top + + alpha_val = np.logspace(-2, 3, dtype=np_dtype) + alpha = constant_op.constant(alpha_val) + sample = random_ops.random_gamma([], alpha, np_dtype(1.0), dtype=dtype) + actual = gradients_impl.gradients(sample, alpha)[0] + + (sample_val, actual_val) = self.evaluate((sample, actual)) + + u = special.gammainc(alpha_val, sample_val) + expected_val = misc.derivative( + lambda alpha_prime: special.gammaincinv(alpha_prime, u), + alpha_val, dx=delta * alpha_val) + + self.assertAllClose(actual_val, expected_val, rtol=1e-3, atol=1e-3) + except ImportError as e: + tf_logging.warn("Cannot use special functions in a test: %s" % str(e)) + + def testCompareToExplicitDerivativeFloat(self): + self._testCompareToExplicitDerivative(dtypes.float32) + + def testCompareToExplicitDerivativeDouble(self): + self._testCompareToExplicitDerivative(dtypes.float64) + + def _testCompareToImplicitDerivative(self, dtype): + """Compare to the implicit reparameterization derivative. + + Let's derive the formula we compare to. + + Start from the fact that CDF maps a random variable to the Uniform + random variable: + igamma(alpha, sample) = u, where u ~ Uniform(0, 1). + + Apply d / dalpha to both sides: + d igamma(alpha, sample) / dalpha + + d igamma(alpha, sample) / dsample * dsample/dalpha = 0 + d igamma(alpha, sample) / dalpha + + d igamma(alpha, sample) / dsample * dsample / dalpha = 0 + dsample/dalpha = - (d igamma(alpha, sample) / dalpha) + / d igamma(alpha, sample) / dsample + + This is the equation (8) of https://arxiv.org/abs/1805.08498 + + Args: + dtype: TensorFlow dtype to perform the computations in. + """ + np_dtype = dtype.as_numpy_dtype + alpha = constant_op.constant(np.logspace(-2, 3, dtype=np_dtype)) + sample = random_ops.random_gamma([], alpha, np_dtype(1.0), dtype=dtype) + actual = gradients_impl.gradients(sample, alpha)[0] + + sample_sg = array_ops.stop_gradient(sample) + cdf = math_ops.igamma(alpha, sample_sg) + dcdf_dalpha, dcdf_dsample = gradients_impl.gradients( + cdf, [alpha, sample_sg]) + # Numerically unstable due to division, do not try at home. + expected = -dcdf_dalpha / dcdf_dsample + + (actual_val, expected_val) = self.evaluate((actual, expected)) + + self.assertAllClose(actual_val, expected_val, rtol=1e-3, atol=1e-3) + + def testCompareToImplicitDerivativeFloat(self): + self._testCompareToImplicitDerivative(dtypes.float32) + + def testCompareToImplicitDerivativeDouble(self): + self._testCompareToImplicitDerivative(dtypes.float64) + + def testAverageAlphaGradient(self): + """Statistical test for the gradient. + + Using the equation (5) of https://arxiv.org/abs/1805.08498, we have + 1 = d/dalpha E_{sample ~ Gamma(alpha, 1)} sample + = E_{sample ~ Gamma(alpha, 1)} dsample/dalpha. + Here we verify that the rhs is fairly close to one. + The convergence speed is not great, so we use many samples and loose bounds. + """ + num_samples = 1000 + alpha = constant_op.constant([0.8, 1e1, 1e3], dtype=dtypes.float32) + sample = random_ops.random_gamma([num_samples], alpha) + # We need to average the gradients, which is equivalent to averaging the + # samples and then doing backprop. + mean_sample = math_ops.reduce_mean(sample, axis=0) + dsample_dalpha = gradients_impl.gradients(mean_sample, alpha)[0] + dsample_dalpha_val = self.evaluate(dsample_dalpha) + self.assertAllClose(dsample_dalpha_val, [1.0] * 3, atol=1e-1, rtol=1e-1) + + def testQuadraticLoss(self): + """Statistical test for the gradient. + + The equation (5) of https://arxiv.org/abs/1805.08498 says + d/dalpha E_{sample ~ Gamma(alpha, 1)} f(sample) + = E_{sample ~ Gamma(alpha, 1)} df(sample)/dalpha. + + Choose a quadratic loss function f(sample) = (sample - t)^2. + Then, the lhs can be computed analytically: + d/dalpha E_{sample ~ Gamma(alpha, 1)} f(sample) + = d/dalpha [ (alpha + alpha^2) - 2 * t * alpha + t^2 ] + = 1 + 2 * alpha - 2 * t. + + We compare the Monte-Carlo estimate of the expectation with the + true gradient. + """ + num_samples = 1000 + t = 0.3 + alpha = 0.5 + expected = 1 + 2 * alpha - 2 * t + + alpha = constant_op.constant(alpha) + sample = random_ops.random_gamma([num_samples], alpha, 1.0) + loss = math_ops.reduce_mean(math_ops.square(sample - t)) + dloss_dalpha = gradients_impl.gradients(loss, alpha)[0] + dloss_dalpha_val = self.evaluate(dloss_dalpha) + self.assertAllClose(expected, dloss_dalpha_val, atol=1e-1, rtol=1e-1) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py index 7be473a5e750d9d6880f112cb0ca89b3ae61a7fd..8e06e1abfb52244e8c1a9b4ed15a270f6048e028 100644 --- a/tensorflow/python/kernel_tests/reader_ops_test.py +++ b/tensorflow/python/kernel_tests/reader_ops_test.py @@ -25,8 +25,6 @@ import shutil import threading import zlib -import six - from tensorflow.core.protobuf import config_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl @@ -703,228 +701,6 @@ class TFRecordReaderTest(TFCompressionTestCase): self.assertAllEqual(self._Record(i, j), v) -class TFRecordWriterTest(TFCompressionTestCase): - - def setUp(self): - super(TFRecordWriterTest, self).setUp() - - def _AssertFilesEqual(self, a, b, equal): - for an, bn in zip(a, b): - with open(an, "rb") as af, open(bn, "rb") as bf: - if equal: - self.assertEqual(af.read(), bf.read()) - else: - self.assertNotEqual(af.read(), bf.read()) - - def testWriteReadZLibFiles(self): - # Write uncompressed then compress manually. - options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE) - files = self._CreateFiles(options, prefix="uncompressed") - zlib_files = [ - self._ZlibCompressFile(fn, "tfrecord_%s.z" % i) - for i, fn in enumerate(files) - ] - self._AssertFilesEqual(files, zlib_files, False) - - # Now write compressd and verify same. - options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) - compressed_files = self._CreateFiles(options, prefix="compressed") - self._AssertFilesEqual(compressed_files, zlib_files, True) - - # Decompress compress and verify same. - uncompressed_files = [ - self._ZlibDecompressFile(fn, "tfrecord_%s.z" % i) - for i, fn in enumerate(compressed_files) - ] - self._AssertFilesEqual(uncompressed_files, files, True) - - def testWriteReadGzipFiles(self): - # Write uncompressed then compress manually. - options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE) - files = self._CreateFiles(options, prefix="uncompressed") - gzip_files = [ - self._GzipCompressFile(fn, "tfrecord_%s.gz" % i) - for i, fn in enumerate(files) - ] - self._AssertFilesEqual(files, gzip_files, False) - - # Now write compressd and verify same. - options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) - compressed_files = self._CreateFiles(options, prefix="compressed") - - # Note: Gzips written by TFRecordWriter add 'tfrecord_0' so - # compressed_files can't be compared with gzip_files - - # Decompress compress and verify same. - uncompressed_files = [ - self._GzipDecompressFile(fn, "tfrecord_%s.gz" % i) - for i, fn in enumerate(compressed_files) - ] - self._AssertFilesEqual(uncompressed_files, files, True) - - -class TFRecordWriterZlibTest(TFCompressionTestCase): - - def testOneEpoch(self): - options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) - files = self._CreateFiles(options) - with self.test_session() as sess: - reader = io_ops.TFRecordReader(name="test_reader", options=options) - queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=()) - key, value = reader.read(queue) - - queue.enqueue_many([files]).run() - queue.close().run() - for i in range(self._num_files): - for j in range(self._num_records): - k, v = sess.run([key, value]) - self.assertTrue(compat.as_text(k).startswith("%s:" % files[i])) - self.assertAllEqual(self._Record(i, j), v) - - with self.assertRaisesOpError("is closed and has insufficient elements " - "\\(requested 1, current size 0\\)"): - k, v = sess.run([key, value]) - - def testZLibFlushRecord(self): - fn = self._WriteRecordsToFile([b"small record"], "small_record") - with open(fn, "rb") as h: - buff = h.read() - - # creating more blocks and trailing blocks shouldn't break reads - compressor = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS) - - output = b"" - for c in buff: - if isinstance(c, int): - c = six.int2byte(c) - output += compressor.compress(c) - output += compressor.flush(zlib.Z_FULL_FLUSH) - - output += compressor.flush(zlib.Z_FULL_FLUSH) - output += compressor.flush(zlib.Z_FULL_FLUSH) - output += compressor.flush(zlib.Z_FINISH) - - # overwrite the original file with the compressed data - with open(fn, "wb") as h: - h.write(output) - - with self.test_session() as sess: - options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) - reader = io_ops.TFRecordReader(name="test_reader", options=options) - queue = data_flow_ops.FIFOQueue(1, [dtypes.string], shapes=()) - key, value = reader.read(queue) - queue.enqueue(fn).run() - queue.close().run() - k, v = sess.run([key, value]) - self.assertTrue(compat.as_text(k).startswith("%s:" % fn)) - self.assertAllEqual(b"small record", v) - - def testZlibReadWrite(self): - """Verify that files produced are zlib compatible.""" - original = [b"foo", b"bar"] - fn = self._WriteRecordsToFile(original, "zlib_read_write.tfrecord") - zfn = self._ZlibCompressFile(fn, "zlib_read_write.tfrecord.z") - - # read the compressed contents and verify. - actual = [] - for r in tf_record.tf_record_iterator( - zfn, options=tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)): - actual.append(r) - self.assertEqual(actual, original) - - def testZlibReadWriteLarge(self): - """Verify that writing large contents also works.""" - - # Make it large (about 5MB) - original = [_TEXT * 10240] - fn = self._WriteRecordsToFile(original, "zlib_read_write_large.tfrecord") - zfn = self._ZlibCompressFile(fn, "zlib_read_write_large.tfrecord.z") - - actual = [] - for r in tf_record.tf_record_iterator( - zfn, options=tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)): - actual.append(r) - self.assertEqual(actual, original) - - def testGzipReadWrite(self): - """Verify that files produced are gzip compatible.""" - original = [b"foo", b"bar"] - fn = self._WriteRecordsToFile(original, "gzip_read_write.tfrecord") - gzfn = self._GzipCompressFile(fn, "tfrecord.gz") - - actual = [] - for r in tf_record.tf_record_iterator( - gzfn, options=tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)): - actual.append(r) - self.assertEqual(actual, original) - - -class TFRecordIteratorTest(TFCompressionTestCase): - - def setUp(self): - super(TFRecordIteratorTest, self).setUp() - self._num_records = 7 - - def testIterator(self): - records = [self._Record(0, i) for i in range(self._num_records)] - options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) - fn = self._WriteRecordsToFile(records, "compressed_records", options) - - reader = tf_record.tf_record_iterator(fn, options) - for expected in records: - record = next(reader) - self.assertAllEqual(expected, record) - with self.assertRaises(StopIteration): - record = next(reader) - - def testWriteZlibRead(self): - """Verify compression with TFRecordWriter is zlib library compatible.""" - original = [b"foo", b"bar"] - options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) - fn = self._WriteRecordsToFile(original, "write_zlib_read.tfrecord.z", - options) - - zfn = self._ZlibDecompressFile(fn, "write_zlib_read.tfrecord") - actual = list(tf_record.tf_record_iterator(zfn)) - self.assertEqual(actual, original) - - def testWriteZlibReadLarge(self): - """Verify compression for large records is zlib library compatible.""" - # Make it large (about 5MB) - original = [_TEXT * 10240] - options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) - fn = self._WriteRecordsToFile(original, "write_zlib_read_large.tfrecord.z", - options) - zfn = self._ZlibDecompressFile(fn, "write_zlib_read_large.tfrecord") - actual = list(tf_record.tf_record_iterator(zfn)) - self.assertEqual(actual, original) - - def testWriteGzipRead(self): - original = [b"foo", b"bar"] - options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) - fn = self._WriteRecordsToFile(original, "write_gzip_read.tfrecord.gz", - options) - - gzfn = self._GzipDecompressFile(fn, "write_gzip_read.tfrecord") - actual = list(tf_record.tf_record_iterator(gzfn)) - self.assertEqual(actual, original) - - def testBadFile(self): - """Verify that tf_record_iterator throws an exception on bad TFRecords.""" - fn = os.path.join(self.get_temp_dir(), "bad_file") - with tf_record.TFRecordWriter(fn) as writer: - writer.write(b"123") - fn_truncated = os.path.join(self.get_temp_dir(), "bad_file_truncated") - with open(fn, "rb") as f: - with open(fn_truncated, "wb") as f2: - # DataLossError requires that we've written the header, so this must - # be at least 12 bytes. - f2.write(f.read(14)) - with self.assertRaises(errors_impl.DataLossError): - for _ in tf_record.tf_record_iterator(fn_truncated): - pass - - class AsyncReaderTest(test.TestCase): def testNoDeadlockFromQueue(self): diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 5267eabf0e4f779a96840069a858609059554c89..0fb0b8895cbc847639999ad1bd23e7fb04c86034 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -145,7 +145,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.assertIn("", str(handle)) self.assertIn("", repr(handle)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDtypeSurvivesIdentity(self): handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) id_handle = array_ops.identity(handle) @@ -156,7 +156,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): v = resource_variable_ops.ResourceVariable(1.0) self.assertNotEqual(v.name, v.assign_add(1.0).name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCreateRead(self): handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) self.evaluate(resource_variable_ops.assign_variable_op( @@ -165,7 +165,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)) self.assertAllEqual(1, value) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testManyAssigns(self): handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) create = resource_variable_ops.assign_variable_op( @@ -183,7 +183,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.assertEqual(f, 1) self.assertEqual(s, 2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAssignAdd(self): handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) self.evaluate(resource_variable_ops.assign_variable_op( @@ -194,7 +194,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)) self.assertEqual(read, 2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterAdd(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -207,7 +207,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[3]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterSub(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -220,7 +220,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[-1]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterMul(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -233,7 +233,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[5]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterDiv(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -246,7 +246,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[2]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterMin(self): with ops.device("cpu:0"): handle = resource_variable_ops.var_handle_op( @@ -283,7 +283,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): meta_graph_two = saver.export_meta_graph(graph=graph) self.assertEqual(meta_graph_def, meta_graph_two) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterMax(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -296,7 +296,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[6]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterAddScalar(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -309,7 +309,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[3]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterSubScalar(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -322,7 +322,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[-1]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterMulScalar(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -335,7 +335,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[5]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterDivScalar(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -348,7 +348,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[2]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterMinScalar(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -361,7 +361,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[3]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterMaxScalar(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -426,7 +426,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): state_ops.scatter_update(ref, indices, updates) self.assertAllEqual(ref.read_value(), [True, True, True]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConstraintArg(self): constraint = lambda x: x v = resource_variable_ops.ResourceVariable( @@ -466,32 +466,32 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): with self.assertRaises(errors.OutOfRangeError): state_ops.count_up_to(v, 1) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInitFnDtype(self): v = resource_variable_ops.ResourceVariable( initial_value=lambda: 1, dtype=dtypes.float32, name="var0") self.assertEqual(dtypes.float32, v.value().dtype) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInitFnNoDtype(self): v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1, name="var2") self.assertEqual(dtypes.int32, v.value().dtype) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInitializeAllVariables(self): v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.float32, name="var0") self.evaluate(variables.global_variables_initializer()) self.assertEqual(1.0, self.evaluate(v.value())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testOperatorOverload(self): v = resource_variable_ops.ResourceVariable(1.0, name="var0") self.evaluate(variables.global_variables_initializer()) self.assertEqual(2.0, self.evaluate(v + v)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAssignMethod(self): v = resource_variable_ops.ResourceVariable(1.0, name="var0") self.evaluate(variables.global_variables_initializer()) @@ -509,7 +509,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.evaluate(assign_without_read) self.assertEqual(4.0, self.evaluate(v.value())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLoad(self): v = resource_variable_ops.ResourceVariable(1.0, name="var0") self.evaluate(variables.global_variables_initializer()) @@ -561,7 +561,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): variable_def=trainable_variable.to_proto()) .trainable) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSparseRead(self): with self.test_session(): init_value = np.reshape(np.arange(np.power(4, 3)), (4, 4, 4)) @@ -583,7 +583,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.assertEquals(v._handle, w._handle) self.assertEquals(v._graph_element, w._graph_element) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAssignAddMethod(self): v = resource_variable_ops.ResourceVariable(1.0, name="var0") self.evaluate(variables.global_variables_initializer()) @@ -601,7 +601,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.evaluate(assign_without_read) self.assertEqual(4.0, self.evaluate(v.value())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAssignSubMethod(self): v = resource_variable_ops.ResourceVariable(3.0, name="var0") self.evaluate(variables.global_variables_initializer()) @@ -619,7 +619,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.evaluate(assign_without_read) self.assertEqual(0.0, self.evaluate(v.value())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDestroyResource(self): v = resource_variable_ops.ResourceVariable(3.0, name="var0") self.evaluate(variables.global_variables_initializer()) @@ -708,7 +708,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype) self.assertEqual(300.0, self.evaluate(w_read)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testShape(self): v = resource_variable_ops.ResourceVariable( name="var4", initial_value=array_ops.ones(shape=[10, 20, 35])) @@ -842,7 +842,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): state_ops.scatter_update(v, [1], [3]) self.assertAllEqual([1.0, 3.0], v.numpy()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterUpdateInvalidArgs(self): v = resource_variable_ops.ResourceVariable([0, 1, 2, 3], name="update") # The exact error and message differ between graph construction (where the diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index fe5ad84c104502f0e09d3a963b406f49d6b97b71..957baf8c6089a6a033f54762fef290399d80cd09 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -81,6 +81,25 @@ class ScalarStateRNNCell(rnn_cell_impl.RNNCell): return (input_, state + 1) +class UnbalancedOutputRNNCell(rnn_cell_impl.RNNCell): + """RNN Cell generating (output, new_state) = (input + 1, state + 1).""" + + @property + def output_size(self): + return tensor_shape.TensorShape(1), tensor_shape.TensorShape((2)) + + @property + def state_size(self): + return tensor_shape.TensorShape([]) + + def zero_state(self, batch_size, dtype): + return array_ops.zeros([], dtype=dtypes.int32) + + def call(self, input_, state, scope=None): + concatenated = array_ops.concat((input_, input_), axis=-1) + return (input_, concatenated), state + 1 + + class TensorArrayStateRNNCell(rnn_cell_impl.RNNCell): """RNN Cell its state as a TensorArray.""" @@ -108,7 +127,7 @@ class RNNTest(test.TestCase): self._seed = 23489 np.random.seed(self._seed) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInvalidSequenceLengthShape(self): cell = Plus1RNNCell() if context.executing_eagerly(): @@ -122,7 +141,7 @@ class RNNTest(test.TestCase): dtype=dtypes.float32, sequence_length=[[4]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBatchSizeFromInput(self): cell = Plus1RNNCell() in_eager_mode = context.executing_eagerly() @@ -162,7 +181,7 @@ class RNNTest(test.TestCase): self.assertEqual(None, outputs.shape[0].value) self.assertEqual(None, state.shape[0].value) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScalarStateIsAccepted(self): cell = ScalarStateRNNCell() in_eager_mode = context.executing_eagerly() @@ -182,7 +201,29 @@ class RNNTest(test.TestCase): self.assertAllEqual([[[1], [2], [3], [4]]], outputs) self.assertAllEqual(4, state) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes + def testUnbalancedOutputIsAccepted(self): + cell = UnbalancedOutputRNNCell() + in_eager_mode = context.executing_eagerly() + + if in_eager_mode: + inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32) + else: + inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1)) + + with self.test_session() as sess: + outputs, state = rnn.dynamic_rnn( + cell, inputs, dtype=dtypes.float32, sequence_length=[4]) + if not in_eager_mode: + outputs, state = sess.run( + [outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]}) + + self.assertIsInstance(outputs, tuple) + self.assertAllEqual([[[1], [2], [3], [4]]], outputs[0]) + self.assertAllEqual([[[1, 1], [2, 2], [3, 3], [4, 4]]], outputs[1]) + self.assertAllEqual(4, state) + + @test_util.run_in_graph_and_eager_modes def testTensorArrayStateIsAccepted(self): cell = TensorArrayStateRNNCell() in_eager_mode = context.executing_eagerly() @@ -215,7 +256,7 @@ class RNNTest(test.TestCase): cell_output, _ = cell(array_ops.zeros(in_shape, dtype), state_output) self.assertAllEqual([batch_size, out_size], cell_output.shape.as_list()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCellsBuild(self): f32 = dtypes.float32 f64 = dtypes.float64 diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py index faa4b49a8d7d8b0169f10592845d3d30a3996c41..f9b9c77bbf7e2a8afdbfbd0929a68856b8aae51c 100644 --- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py +++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py @@ -369,7 +369,7 @@ class ScatterNdTest(test.TestCase): del input_ # input_ is not used in scatter_nd return array_ops.scatter_nd(indices, updates, shape) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInvalidShape(self): # TODO(apassos) figure out how to unify these errors with self.assertRaises(errors.InvalidArgumentError diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py index 7368251ab69574cc6cba703e605f108c6ab45649..34e34d9d1b2034d8679844f051358f020a44587a 100644 --- a/tensorflow/python/kernel_tests/shape_ops_test.py +++ b/tensorflow/python/kernel_tests/shape_ops_test.py @@ -642,6 +642,29 @@ class TileTest(test.TestCase): err = gradient_checker.compute_gradient_error(a, [4, 2], tiled, [4, 4]) self.assertLess(err, 1e-3) + def testGradientWithSparseGradWithRank1(self): + inputs = constant_op.constant([1.0, 2.0, 3.0, 4.0], + dtype=dtypes.float32) + outputs = array_ops.gather(array_ops.tile(inputs, [3]), + [1, 5, 9, 3, 7, 2, 2, 2]) + with self.test_session(): + error = gradient_checker.compute_gradient_error( + inputs, inputs.get_shape().as_list(), + outputs, outputs.get_shape().as_list()) + self.assertLess(error, 1e-4) + + def testGradientWithSparseGradWithRank3(self): + inputs = constant_op.constant([1.0, 2.0, 3.0, 4.0], + dtype=dtypes.float32) + inputs = array_ops.reshape(inputs, [-1, 1, 1]) + outputs = array_ops.gather(array_ops.tile(inputs, [3, 4, 2]), + [1, 5, 9, 3, 7, 2, 2, 2]) + with self.test_session(): + error = gradient_checker.compute_gradient_error( + inputs, inputs.get_shape().as_list(), + outputs, outputs.get_shape().as_list()) + self.assertLess(error, 1e-4) + def testShapeFunctionEdgeCases(self): # Unknown multiples shape. inp = constant_op.constant(0.0, shape=[4, 4, 4, 4]) diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py index 5fc9bef21816e3a12f0d274bab1fc82a83546422..402f67619b41a5f13c6603eb6665974a09a8f4fb 100644 --- a/tensorflow/python/kernel_tests/slice_op_test.py +++ b/tensorflow/python/kernel_tests/slice_op_test.py @@ -225,7 +225,7 @@ class SliceTest(test.TestCase): self.assertAllEqual(m1.get_shape().as_list(), [1, 2, 3]) m2 = array_ops.slice(z, [0, 0, 0], [constant_op.constant(1) + 0, 2, -1]) - self.assertAllEqual(m2.get_shape().as_list(), [None, 2, None]) + self.assertAllEqual(m2.get_shape().as_list(), [1, 2, 3]) def _testGradientSlice(self, input_shape, slice_begin, slice_size): diff --git a/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py b/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py index 27b39a626fcc6b2705bf9e797b5293ed3f1c7820..3847cebc7dcabd66c26a4e4551e5856c6a927a33 100644 --- a/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py +++ b/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py @@ -300,6 +300,51 @@ class SerializeSparseTest(test.TestCase): sparse_ops.serialize_many_sparse, sparse_ops.deserialize_sparse, dtypes.variant) + def testVariantSerializeDeserializeScalar(self): + with self.test_session(use_gpu=False) as sess: + indices_value = np.array([[]], dtype=np.int64) + values_value = np.array([37], dtype=np.int32) + shape_value = np.array([], dtype=np.int64) + sparse_tensor = self._SparseTensorPlaceholder() + serialized = sparse_ops.serialize_sparse( + sparse_tensor, out_type=dtypes.variant) + deserialized = sparse_ops.deserialize_sparse( + serialized, dtype=dtypes.int32) + deserialized_value = sess.run( + deserialized, + feed_dict={ + sparse_tensor.indices: indices_value, + sparse_tensor.values: values_value, + sparse_tensor.dense_shape: shape_value + }) + self.assertAllEqual(deserialized_value.indices, indices_value) + self.assertAllEqual(deserialized_value.values, values_value) + self.assertAllEqual(deserialized_value.dense_shape, shape_value) + + def testVariantSerializeDeserializeScalarBatch(self): + with self.test_session(use_gpu=False) as sess: + indices_value = np.array([[]], dtype=np.int64) + values_value = np.array([37], dtype=np.int32) + shape_value = np.array([], dtype=np.int64) + sparse_tensor = self._SparseTensorPlaceholder() + serialized = sparse_ops.serialize_sparse( + sparse_tensor, out_type=dtypes.variant) + stacked = array_ops.stack([serialized, serialized]) + deserialized = sparse_ops.deserialize_sparse(stacked, dtype=dtypes.int32) + deserialized_value = sess.run( + deserialized, + feed_dict={ + sparse_tensor.indices: indices_value, + sparse_tensor.values: values_value, + sparse_tensor.dense_shape: shape_value + }) + self.assertAllEqual(deserialized_value.indices, + np.array([[0], [1]], dtype=np.int64)) + self.assertAllEqual(deserialized_value.values, + np.array([37, 37], dtype=np.int32)) + self.assertAllEqual(deserialized_value.dense_shape, + np.array([2], dtype=np.int64)) + def _testDeserializeFailsWrongTypeHelper(self, serialize_fn, deserialize_fn, diff --git a/tensorflow/python/kernel_tests/sparse_slice_op_test.py b/tensorflow/python/kernel_tests/sparse_slice_op_test.py index da116601f833cc6b471e383e030c5fbe93b52ac5..97f30daf4a9c9615e1b42a1ba94e693e166bbc1c 100644 --- a/tensorflow/python/kernel_tests/sparse_slice_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_slice_op_test.py @@ -21,13 +21,15 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import sparse_ops +import tensorflow.python.ops.sparse_grad # pylint: disable=unused-import from tensorflow.python.platform import test class SparseSliceOpTest(test.TestCase): - def _SparseTensor_4x6(self): + def _SparseTensor_4x6(self, val_dtype=np.int64): # [0 | |2 | |4 |5 ] # [ |11| |13|14| ] # [20| | |23| |25] @@ -37,7 +39,7 @@ class SparseSliceOpTest(test.TestCase): [2, 3], [2, 5], [3, 0], [3, 2], [3, 3], [3, 5]]).astype( np.int64) val = np.array([0, 2, 4, 5, 11, 13, 14, 20, 23, 25, 30, 32, 33, 35]).astype( - np.int64) + val_dtype) shape = np.array([4, 6]).astype(np.int64) return sparse_tensor.SparseTensor(ind, val, shape) @@ -244,6 +246,22 @@ class SparseSliceOpTest(test.TestCase): self.assertAllEqual(sparse_tensor5.values.eval(), [5, 25, 35]) self.assertAllEqual(sparse_tensor5.dense_shape.eval(), [4, 1]) + def testGradients(self): + sp_input = self._SparseTensor_4x6(val_dtype=np.float32) + start_and_size = [([0, 0], [4, 2]), + ([0, 2], [5, 2]), + ([0, 4], [5, 3])] + + with self.test_session(use_gpu=False): + for start, size in start_and_size: + sp_output = sparse_ops.sparse_slice(sp_input, start, size) + nnz_in = len(sp_input.values.eval()) + nnz_out = len(sp_output.values.eval()) + + err = gradient_checker.compute_gradient_error( + [sp_input.values], [(nnz_in,)], sp_output.values, (nnz_out,)) + self.assertLess(err, 1e-3) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/kernel_tests/split_op_test.py b/tensorflow/python/kernel_tests/split_op_test.py index 8cfee3eb933afcea7a58d5632948b87b0c4c10df..419cd5ecdafab92910cd06fb18148796f70afb44 100644 --- a/tensorflow/python/kernel_tests/split_op_test.py +++ b/tensorflow/python/kernel_tests/split_op_test.py @@ -95,7 +95,7 @@ class SplitOpTest(test.TestCase): sess.run(array_ops.split(value, size_splits), {size_splits: [2, 2, 6]}) self.assertTrue("Cannot infer num from shape" in str(context.exception)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testExplicitNum(self): size_splits = array_ops.constant([2, 2, 6], dtype=dtypes.int32) value = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] @@ -109,7 +109,7 @@ class SplitOpTest(test.TestCase): self.assertAllEqual(r[1], value[2:4]) self.assertAllEqual(r[2], value[4:]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testListOfScalarTensors(self): a = math_ops.to_int32(5) b = math_ops.to_int32(6) @@ -168,7 +168,7 @@ class SplitOpTest(test.TestCase): offset += size_splits[i] self.assertAllEqual(result[i], inp[slices]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSpecialCasesVariable(self): self._testSpecialCasesVariable() for dtype in _TEST_DTYPES: @@ -210,13 +210,13 @@ class SplitOpTest(test.TestCase): self.assertAllEqual(np_ans[i], out[i]) self.assertShapeEqual(np_ans[i], tf_ans[i]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSplitRows(self): for dtype in _TEST_DTYPES: inp = self._makeData((4, 4), dtype) self._compare(inp, 0, 4) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSplitCols(self): for dtype in _TEST_DTYPES: inp = self._makeData((4, 4), dtype) @@ -232,7 +232,7 @@ class SplitOpTest(test.TestCase): self.assertEqual(out[i].shape, expected_shape) self.assertEqual(expected_shape, tf_ans[i].get_shape()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEmpty(self): # Note: np.split returns a rank-0 empty ndarray # if the input ndarray is empty. @@ -244,7 +244,7 @@ class SplitOpTest(test.TestCase): self._testEmpty(inp, 2, 3, (8, 0, 7)) self._testEmpty(inp, 2, 7, (8, 0, 3)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testIdentity(self): for dtype in _TEST_DTYPES: inp = self._makeData((2, 2, 2), dtype) @@ -252,7 +252,7 @@ class SplitOpTest(test.TestCase): self._compare(inp, 1, 1) self._compare(inp, 2, 1) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSplitDim0(self): for dtype in _TEST_DTYPES: self._compare(self._makeData((6, 10, 18), dtype), 0, 3) @@ -281,7 +281,7 @@ class SplitOpTest(test.TestCase): offset += length self.assertAllEqual(result[i], inp[slices]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRandom(self): for dtype in _TEST_DTYPES: for _ in range(5): diff --git a/tensorflow/python/kernel_tests/template_test.py b/tensorflow/python/kernel_tests/template_test.py index 1b935d5286729e9e802c56e90e2ae7ab72a6e080..0b3a396d6bf46fb46416662a9443ed7b5811e15c 100644 --- a/tensorflow/python/kernel_tests/template_test.py +++ b/tensorflow/python/kernel_tests/template_test.py @@ -150,7 +150,7 @@ class TemplateTest(test.TestCase): # Parameters are tied, so the loss should have gone down after training. self.assertLess(final_test_loss.numpy(), initial_test_loss.numpy()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_skip_stack_frames(self): first = traceback.format_stack() second = traceback.format_stack() @@ -158,7 +158,7 @@ class TemplateTest(test.TestCase): self.assertEqual(1, len(result)) self.assertNotEqual(len(first), len(result)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_template_with_name(self): tmpl1 = template.make_template("s1", variable_scoped_function) tmpl2 = template.make_template("s1", variable_scoped_function) @@ -204,7 +204,7 @@ class TemplateTest(test.TestCase): self.assertEqual(v1, v3) self.assertEqual("s1/dummy:0", v1.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_template_in_scope(self): tmpl1 = template.make_template("s1", variable_scoped_function) tmpl2 = template.make_template("s1", variable_scoped_function) @@ -221,7 +221,7 @@ class TemplateTest(test.TestCase): self.assertEqual("scope/s1/dummy:0", v1.name) self.assertEqual("scope/s1_1/dummy:0", v3.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_template_with_internal_reuse(self): tmpl1 = template.make_template("s1", internally_variable_scoped_function) tmpl2 = template.make_template("s1", internally_variable_scoped_function) @@ -237,13 +237,13 @@ class TemplateTest(test.TestCase): with self.assertRaises(ValueError): tmpl1("not_test") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_template_without_name(self): with self.assertRaisesRegexp( ValueError, "name cannot be None."): template.make_template(None, variable_scoped_function) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_make_template(self): # Test both that we can call it with positional and keywords. tmpl1 = template.make_template( @@ -266,7 +266,7 @@ class TemplateTest(test.TestCase): with self.assertRaises(ValueError): tmpl() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_enforces_no_extra_trainable_variables_eager(self): tmpl = template.make_template("s", function_with_side_create, @@ -287,7 +287,7 @@ class TemplateTest(test.TestCase): trainable=False) self.assertEqual(tmpl(name="1"), tmpl(name="2")) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_internal_variable_reuse(self): def nested(): @@ -310,7 +310,7 @@ class TemplateTest(test.TestCase): self.assertEqual("s1/nested/x:0", v1.name) self.assertEqual("s1_1/nested/x:0", v3.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_nested_templates(self): def nested_template(): @@ -360,7 +360,7 @@ class TemplateTest(test.TestCase): self.assertEqual("nested", tmpl1._checkpoint_dependencies[0].name) self.assertEqual("nested_1", tmpl1._checkpoint_dependencies[1].name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_nested_templates_with_defun(self): def variable_scoped_function_no_return_value(trainable=True): @@ -429,7 +429,7 @@ class TemplateTest(test.TestCase): "a", partial, create_graph_function_=True) self.assertAllEqual(tmpl(ops.convert_to_tensor(1.0)), 2.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_immediate_scope_creation(self): # Create templates in scope a then call in scope b. make_template should # capture the scope the first time it is called, and make_immediate_template @@ -454,7 +454,7 @@ class TemplateTest(test.TestCase): self.assertEqual("ctor_scope/a/dummy:0", inner_imm_var.name) self.assertEqual("call_scope/b/dummy:0", inner_defer_var.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_scope_access(self): # Ensure that we can access the scope inside the template, because the name # of that scope may be different from the name we pass to make_template, due @@ -479,7 +479,7 @@ class TemplateTest(test.TestCase): # Template is called at the top level, so there is no preceding "foo_2". self.assertEqual(tc.variable_scope.name, "blah") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_custom_getter(self): # Custom getter that maintains call count and forwards to true getter custom_getter_count = [0] @@ -512,7 +512,7 @@ class TemplateTest(test.TestCase): tmpl2() self.assertEqual(custom_getter_count[0], 2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_fails_gracefully(self): for create_scope_now in [True, False]: def module_function_with_one_arg(inputs): @@ -535,7 +535,7 @@ class TemplateTest(test.TestCase): templatized_function(data) self.assertTrue(templatized_function._variables_created) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_name_scopes_for_variable_scopes(self): # Test that name scopes are not unnecessarily uniquified (but are # still uniquified when necessary). @@ -586,7 +586,7 @@ class TemplateTest(test.TestCase): "Second application of template should also get " "a freshly uniquified name scope.") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_global_variables(self): # Make sure global_variables are created. with variable_scope.variable_scope("foo"): @@ -608,7 +608,7 @@ class TemplateTest(test.TestCase): self.assertEqual(1, len(ta.global_variables)) self.assertEqual(2, len(tb.global_variables)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_trainable_variables(self): # Make sure trainable_variables are created. with variable_scope.variable_scope("foo2"): @@ -632,7 +632,7 @@ class TemplateTest(test.TestCase): self.assertEqual(1, len(ta.variables)) self.assertEqual(1, len(tb.variables)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_non_trainable_variables(self): # Make sure non_trainable_variables are created. with variable_scope.variable_scope("foo2"): @@ -675,7 +675,7 @@ class TemplateTest(test.TestCase): self.assertEqual(0, len(ta.local_variables)) self.assertEqual(1, len(tb.local_variables)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_make_template_with_defun(self): def variable_scoped_function_no_return_value(scope_name): diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py index ea06357804f45bbba3a9e7e847659d47bf52bffb..6de6fbe7679fa8e95d3032b04fb81b43ac3a60d9 100644 --- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py +++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py @@ -75,7 +75,7 @@ class TensorArrayTest(test.TestCase): super(TensorArrayTest, cls).tearDownClass() session_lib.Session.reset(cls._workers[0].target) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayWriteRead(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( @@ -123,11 +123,11 @@ class TensorArrayTest(test.TestCase): self._testTensorArrayWritePack(dtypes.complex128) self._testTensorArrayWritePack(dtypes.string) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayWritePack(self): self._testTensorArrayWritePackMaybeLegacy() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEmptyTensorArrayPack(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( @@ -161,7 +161,7 @@ class TensorArrayTest(test.TestCase): convert([[4.0, 5.0], [104.0, 105.0], [204.0, 205.0], [6.0, 7.0], [106.0, 107.0], [8.0, 9.0]]), c0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayWriteConcat(self): self._testTensorArrayWriteConcat(dtypes.float32) self._testTensorArrayWriteConcat(dtypes.float64) @@ -184,7 +184,7 @@ class TensorArrayTest(test.TestCase): self.assertAllEqual([[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]], self.evaluate(ta.write(1, [[4.0, 5.0]]).concat())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayReadOrPackNotAllValuesAvailableFillsZeros(self): self._testTensorArrayReadOrPackNotAllValuesAvailableFillsZeros() @@ -200,7 +200,7 @@ class TensorArrayTest(test.TestCase): self.assertAllEqual([[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]], self.evaluate(ta.write(1, [[4.0, 5.0]]).concat())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayReadOrPackNotAllValuesAvailableInferShapeFillsZeros(self): self._testTensorArrayReadOrPackNotAllValuesAvailableInferShapeFillsZeros() @@ -251,7 +251,7 @@ class TensorArrayTest(test.TestCase): self._testTensorArrayUnpackRead(dtypes.complex128) self._testTensorArrayUnpackRead(dtypes.string) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayUnpackRead(self): self._testTensorArrayUnpackReadMaybeLegacy() @@ -297,7 +297,7 @@ class TensorArrayTest(test.TestCase): self.assertAllEqual(convert([]).reshape(0, 2), d1) self.assertAllEqual(convert([[3.0, 301.0]]), d2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArraySplitRead(self): self._testTensorArraySplitRead(dtypes.float32) self._testTensorArraySplitRead(dtypes.float64) @@ -397,7 +397,7 @@ class TensorArrayTest(test.TestCase): self.assertAllEqual(t_g_ta_0, t_g_ta_1) self.assertAllEqual([[4.0, 5.0]], d_r1_0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayWriteWrongIndexOrDataTypeFails(self): with self.test_session(use_gpu=True): ta = _make_ta(3, "foo", dtype=dtypes.float32) @@ -416,7 +416,7 @@ class TensorArrayTest(test.TestCase): "resizeable and size is: 3"): self.evaluate(ta.write(3, 3.0).flow) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayReadWrongIndexOrDataTypeFails(self): with self.test_session(use_gpu=True): ta = _make_ta(3, "foo", dtype=dtypes.float32) @@ -450,7 +450,7 @@ class TensorArrayTest(test.TestCase): "it has already been written to."): self.evaluate(ta.write(2, 3.0).write(2, 3.0).flow) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayConcatIncompatibleShapesFails(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( @@ -482,7 +482,7 @@ class TensorArrayTest(test.TestCase): with self.assertRaisesOpError("shape"): self.evaluate(w3.concat()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArraySplitIncompatibleShapesFails(self): with self.test_session(use_gpu=True): in_eager_mode = context.executing_eagerly() @@ -603,7 +603,7 @@ class TensorArrayTest(test.TestCase): self.assertAllClose(fed_value, sess.run(read_value, feed_dict={value: fed_value})) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMultiTensorArray(self): with self.test_session(use_gpu=True): h1 = tensor_array_ops.TensorArray( @@ -706,7 +706,7 @@ class TensorArrayTest(test.TestCase): def testTensorArrayGradientWritePackConcatAndRead(self): self._testTensorArrayGradientWritePackConcatAndRead() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayReadTwice(self): with self.test_session(use_gpu=True): value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) @@ -811,14 +811,14 @@ class TensorArrayTest(test.TestCase): def testTensorArrayGradientDynamicUnpackRead(self): self._testTensorArrayGradientDynamicUnpackRead() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCloseTensorArray(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) self.evaluate(ta.close()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSizeTensorArray(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( @@ -826,7 +826,7 @@ class TensorArrayTest(test.TestCase): s = ta.size() self.assertAllEqual(3, self.evaluate(s)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testWriteCloseTensorArray(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( @@ -924,7 +924,7 @@ class TensorArrayTest(test.TestCase): self.assertAllClose(grad_val.sum(axis=0), var_grad_t) self.assertAllClose(grad_val.sum(axis=0), state0_grad_t) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testWhileLoopWritePackGradients(self): self._testWhileLoopWritePackGradients( dynamic_size=False, dtype=dtypes.float32) @@ -936,7 +936,7 @@ class TensorArrayTest(test.TestCase): self._testWhileLoopWritePackGradients( dynamic_size=True, dtype=dtypes.float32) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradSerialTwoLoops(self): with self.test_session(use_gpu=True): def loop(x): @@ -1113,7 +1113,7 @@ class TensorArrayTest(test.TestCase): r5 = w5.read(0) self.assertAllEqual([5, 4, 2, 3], r5.get_shape().as_list()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def _testUnpackShape(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( @@ -1147,7 +1147,7 @@ class TensorArrayTest(test.TestCase): def testUnpackShape(self): self._testUnpackShape() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSplitShape(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( @@ -1289,7 +1289,7 @@ class TensorArrayTest(test.TestCase): self.assertAllEqual([10.0, -10.0], read_vals[1]) self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayWriteGatherAndGradients(self): with self.test_session(use_gpu=True) as session: ta = tensor_array_ops.TensorArray( @@ -1433,7 +1433,7 @@ class TensorArrayTest(test.TestCase): self.assertFalse( [s for s in dev_stats[d] if "/TensorArray" in s.node_name]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayIdentity(self): with self.test_session(use_gpu=True): ta0 = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2, diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index 2ee53df9317331dafd96f7884e9a8728cf443923..1e59a8c9bf58c92c6c8ef5c92ca6340027c985f8 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -57,7 +57,7 @@ class VariableScopeTest(test.TestCase): v1 = vs.get_variable("v", [1]) self.assertEqual(v, v1) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testResource(self): vs = variable_scope._get_default_variable_store() v1 = vs.get_variable("v", [1], use_resource=True) @@ -87,7 +87,7 @@ class VariableScopeTest(test.TestCase): self.assertEqual( set(expected_names), set([v.name for v in vs._vars.values()])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVarScopeInitializer(self): init = init_ops.constant_initializer(0.3) with variable_scope.variable_scope("tower0") as tower: @@ -100,7 +100,7 @@ class VariableScopeTest(test.TestCase): self.evaluate(variables_lib.variables_initializer([w])) self.assertAllClose(self.evaluate(w.value()), 0.3) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVarScopeConstraint(self): constraint = lambda x: 0. * x with variable_scope.variable_scope("tower1") as tower: @@ -117,7 +117,7 @@ class VariableScopeTest(test.TestCase): variables_lib.global_variables_initializer().run() self.assertAllEqual(compat.as_bytes(v.eval()), b"") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVarScopeDType(self): with variable_scope.variable_scope("tower2") as tower: with variable_scope.variable_scope("foo", dtype=dtypes.float16): @@ -197,7 +197,7 @@ class VariableScopeTest(test.TestCase): self.assertAllEqual([v1, v2], [v3, v4]) f() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerVariablesStoreAddsToCollections(self): store = variable_scope.EagerVariableStore() with store.as_default(): @@ -214,7 +214,7 @@ class VariableScopeTest(test.TestCase): self.assertEqual( ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES), [concat]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerVariablesOutsideStoreNotAddedToCollections(self): if not context.executing_eagerly(): return @@ -223,7 +223,7 @@ class VariableScopeTest(test.TestCase): self.assertFalse(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) self.assertFalse(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInitFromNonTensorValue(self): v = variable_scope.get_variable("v4", initializer=4, dtype=dtypes.int32) self.evaluate(variables_lib.variables_initializer([v])) @@ -239,7 +239,7 @@ class VariableScopeTest(test.TestCase): with self.assertRaises(error): variable_scope.get_variable("x4", initializer={}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInitFromNonInitializer(self): # Test various dtypes with zeros initializer as following: types = [ @@ -294,7 +294,7 @@ class VariableScopeTest(test.TestCase): v_tower = variable_scope.get_variable("v", []) self.assertFalse(v_tower.value().device.startswith(caching_device)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVarScopeRegularizer(self): init = init_ops.constant_initializer(0.3) @@ -339,7 +339,7 @@ class VariableScopeTest(test.TestCase): losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) self.assertEqual(3, len(losses)) # No new loss added. - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInitializeFromValue(self): init = constant_op.constant(0.1) w = variable_scope.get_variable("v", initializer=init) @@ -428,7 +428,7 @@ class VariableScopeTest(test.TestCase): sess.run(v0.initializer) sess.run(add) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetVariableScope(self): # Test the get_variable_scope() function and setting properties of result. init = init_ops.constant_initializer(0.3) @@ -449,7 +449,7 @@ class VariableScopeTest(test.TestCase): new_init = variable_scope.get_variable_scope().initializer self.assertEqual(new_init, None) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVarScope(self): with variable_scope.variable_scope("tower4") as tower: self.assertEqual(tower.name, "tower4") @@ -468,7 +468,7 @@ class VariableScopeTest(test.TestCase): with ops.name_scope("scope") as sc: self.assertEqual(sc, "tower6/tower4/scope/") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVarScopeNameScope(self): with ops.name_scope("testVarScopeNameScope1"): with variable_scope.variable_scope("tower") as tower: @@ -961,7 +961,7 @@ class VariableScopeTest(test.TestCase): self.assertEqual( constant_op.constant([], name="c").name, "another/inner/c:0") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetLocalVar(self): # Check that local variable respects naming. with variable_scope.variable_scope("outer") as outer: diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index eda036ece4a7d74e5752e80a4a2f4e4ada1b0a38..b8969a41aba1f8ee84233ce7ac398193183d292f 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -191,7 +191,7 @@ class Layer(base_layer.Layer): RuntimeError: If called with partioned variable regularization and eager execution is enabled. """ - + def _should_add_regularizer(variable, existing_variable_set): if isinstance(variable, tf_variables.PartitionedVariable): for var in variable: diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py index ab49e37b90e183034ae7ab720fa92b06f39b2aed..298e96e711cbf8a0f625f95d737d1e7a83f4431d 100644 --- a/tensorflow/python/layers/base_test.py +++ b/tensorflow/python/layers/base_test.py @@ -39,7 +39,7 @@ from tensorflow.python.platform import test class BaseLayerTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLayerProperties(self): layer = base_layers.Layer(name='my_layer') self.assertEqual(layer.variables, []) @@ -53,13 +53,13 @@ class BaseLayerTest(test.TestCase): layer = base_layers.Layer(name='my_layer', trainable=False) self.assertEqual(layer.trainable, False) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInt64Layer(self): layer = base_layers.Layer(name='my_layer', dtype='int64') layer.add_variable('my_var', [2, 2]) self.assertEqual(layer.name, 'my_layer') - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAddWeight(self): layer = base_layers.Layer(name='my_layer') @@ -116,7 +116,7 @@ class BaseLayerTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'activity_regularizer'): core_layers.Dense(1, activity_regularizer=lambda *args, **kwargs: 0.) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCall(self): class MyLayer(base_layers.Layer): @@ -132,7 +132,7 @@ class BaseLayerTest(test.TestCase): # op is only supported in GRAPH mode self.assertEqual(outputs.op.name, 'my_layer/Square') - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDeepCopy(self): class MyLayer(base_layers.Layer): @@ -155,7 +155,7 @@ class BaseLayerTest(test.TestCase): self.assertEqual(layer_copy._graph, layer._graph) self.assertEqual(layer_copy._private_tensor, layer._private_tensor) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScopeNaming(self): class PrivateLayer(base_layers.Layer): @@ -203,7 +203,7 @@ class BaseLayerTest(test.TestCase): my_layer_scoped1.apply(inputs) self.assertEqual(my_layer_scoped1._scope.name, 'var_scope/my_layer_1') - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInputSpecNdimCheck(self): class CustomerLayer(base_layers.Layer): @@ -230,7 +230,7 @@ class BaseLayerTest(test.TestCase): layer = CustomerLayer() layer.apply(constant_op.constant([[1], [2]])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInputSpecMinNdimCheck(self): class CustomerLayer(base_layers.Layer): @@ -258,7 +258,7 @@ class BaseLayerTest(test.TestCase): layer = CustomerLayer() layer.apply(constant_op.constant([[[1], [2]]])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInputSpecMaxNdimCheck(self): class CustomerLayer(base_layers.Layer): @@ -286,7 +286,7 @@ class BaseLayerTest(test.TestCase): layer = CustomerLayer() layer.apply(constant_op.constant([[1], [2]])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInputSpecDtypeCheck(self): class CustomerLayer(base_layers.Layer): @@ -306,7 +306,7 @@ class BaseLayerTest(test.TestCase): layer = CustomerLayer() layer.apply(constant_op.constant(1.0, dtype=dtypes.float32)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInputSpecAxesCheck(self): class CustomerLayer(base_layers.Layer): @@ -328,7 +328,7 @@ class BaseLayerTest(test.TestCase): layer = CustomerLayer() layer.apply(constant_op.constant([[1, 2], [3, 4], [5, 6]])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInputSpecShapeCheck(self): class CustomerLayer(base_layers.Layer): @@ -348,7 +348,7 @@ class BaseLayerTest(test.TestCase): layer = CustomerLayer() layer.apply(constant_op.constant([[1, 2, 3], [4, 5, 6]])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoInputSpec(self): class CustomerLayer(base_layers.Layer): @@ -369,7 +369,7 @@ class BaseLayerTest(test.TestCase): layer.apply(array_ops.placeholder('int32')) layer.apply(array_ops.placeholder('int32', shape=(2, 3))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_count_params(self): dense = core_layers.Dense(16) dense.build((None, 4)) @@ -379,7 +379,7 @@ class BaseLayerTest(test.TestCase): with self.assertRaises(ValueError): dense.count_params() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDictInputOutput(self): class DictLayer(base_layers.Layer): @@ -589,6 +589,5 @@ class BaseLayerTest(test.TestCase): ValueError, 'Input graph and Layer graph are not the same'): layer.apply(constant_op.constant([[1.]])) - if __name__ == '__main__': test.main() diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py index cf45b07637108422f1c612390bb01efdad6d5bcf..040c1cddc0f2540eec5fcf3442bed3f4800bec7c 100644 --- a/tensorflow/python/layers/core_test.py +++ b/tensorflow/python/layers/core_test.py @@ -41,7 +41,7 @@ from tensorflow.python.platform import test class DenseTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDenseProperties(self): dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense') self.assertEqual(dense.units, 2) @@ -91,14 +91,14 @@ class DenseTest(test.TestCase): core_layers.Dense(5)(inputs) core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')(inputs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCallTensorDot(self): dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense') inputs = random_ops.random_uniform((5, 4, 3), seed=1) outputs = dense(inputs) self.assertListEqual([5, 4, 2], outputs.get_shape().as_list()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoBias(self): dense = core_layers.Dense(2, use_bias=False, name='my_dense') inputs = random_ops.random_uniform((5, 2), seed=1) @@ -112,7 +112,7 @@ class DenseTest(test.TestCase): self.assertEqual(dense.kernel.name, 'my_dense/kernel:0') self.assertEqual(dense.bias, None) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNonTrainable(self): dense = core_layers.Dense(2, trainable=False, name='my_dense') inputs = random_ops.random_uniform((5, 2), seed=1) @@ -125,7 +125,7 @@ class DenseTest(test.TestCase): self.assertEqual( len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testOutputShape(self): dense = core_layers.Dense(7, activation=nn_ops.relu, name='my_dense') inputs = random_ops.random_uniform((5, 3), seed=1) @@ -165,7 +165,7 @@ class DenseTest(test.TestCase): dense = core_layers.Dense(4, name='my_dense') dense(inputs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testActivation(self): dense = core_layers.Dense(2, activation=nn_ops.relu, name='dense1') inputs = random_ops.random_uniform((5, 3), seed=1) @@ -325,7 +325,7 @@ class DenseTest(test.TestCase): var_key = 'test2/dense/kernel' self.assertEqual(var_dict[var_key].name, '%s:0' % var_key) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testComputeOutputShape(self): dense = core_layers.Dense(2, activation=nn_ops.relu, name='dense1') ts = tensor_shape.TensorShape @@ -347,7 +347,7 @@ class DenseTest(test.TestCase): dense.compute_output_shape(ts([None, 4, 3])).as_list()) # pylint: enable=protected-access - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConstraints(self): k_constraint = lambda x: x / math_ops.reduce_sum(x) b_constraint = lambda x: x / math_ops.reduce_max(x) @@ -369,7 +369,7 @@ def _get_variable_dict_from_varstore(): class DropoutTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDropoutProperties(self): dp = core_layers.Dropout(0.5, name='dropout') self.assertEqual(dp.rate, 0.5) @@ -377,7 +377,7 @@ class DropoutTest(test.TestCase): dp.apply(array_ops.ones(())) self.assertEqual(dp.name, 'dropout') - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBooleanLearningPhase(self): dp = core_layers.Dropout(0.5) inputs = array_ops.ones((5, 3)) @@ -402,7 +402,7 @@ class DropoutTest(test.TestCase): np_output = sess.run(dropped, feed_dict={training: False}) self.assertAllClose(np.ones((5, 5)), np_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDynamicNoiseShape(self): inputs = array_ops.ones((5, 3, 2)) noise_shape = [None, 1, None] diff --git a/tensorflow/python/lib/core/bfloat16.cc b/tensorflow/python/lib/core/bfloat16.cc index 77fa2c1f66d2214dbb08e4d0ad3437fa4fe02822..fde3a83770280038b777a141693d117dace4b41f 100644 --- a/tensorflow/python/lib/core/bfloat16.cc +++ b/tensorflow/python/lib/core/bfloat16.cc @@ -446,6 +446,16 @@ npy_bool NPyBfloat16_NonZero(void* data, void* arr) { return x != static_cast(0); } +int NPyBfloat16_Fill(void* buffer_raw, npy_intp length, void* ignored) { + bfloat16* const buffer = reinterpret_cast(buffer_raw); + const float start(buffer[0]); + const float delta = static_cast(buffer[1]) - start; + for (npy_intp i = 2; i < length; ++i) { + buffer[i] = static_cast(start + i * delta); + } + return 0; +} + // NumPy casts // Performs a NumPy array cast from type 'From' to 'To'. @@ -548,6 +558,7 @@ bool Initialize() { NPyBfloat16_ArrFuncs.copyswapn = NPyBfloat16_CopySwapN; NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap; NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero; + NPyBfloat16_ArrFuncs.fill = NPyBfloat16_Fill; Py_TYPE(&NPyBfloat16_Descr) = &PyArrayDescr_Type; npy_bfloat16_ = PyArray_RegisterDataType(&NPyBfloat16_Descr); diff --git a/tensorflow/python/lib/core/bfloat16_test.py b/tensorflow/python/lib/core/bfloat16_test.py index 09d4b01fa43babdc09f8f255e79bbed539ddc04c..bc928cd9e5ef4d5a0ec0ce73e853e3e022a1f6fa 100644 --- a/tensorflow/python/lib/core/bfloat16_test.py +++ b/tensorflow/python/lib/core/bfloat16_test.py @@ -245,6 +245,20 @@ class Bfloat16NumPyTest(test.TestCase): np.logaddexp(x.astype(bfloat16), y.astype(bfloat16)), atol=2e-2) + def testArange(self): + self.assertAllEqual( + np.arange(100, dtype=np.float32).astype(bfloat16), + np.arange(100, dtype=bfloat16)) + self.assertAllEqual( + np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16), + np.arange(-10.5, 7.8, 0.5, dtype=bfloat16)) + self.assertAllEqual( + np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16), + np.arange(-0., -7., -0.25, dtype=bfloat16)) + self.assertAllEqual( + np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16), + np.arange(-16384., 16384., 64., dtype=bfloat16)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/lib/core/numpy.h b/tensorflow/python/lib/core/numpy.h index 98354083c7e06103166a6fe535b153eaaf201c17..d4621d61ee98b9eb4b19213145059d242c88f40c 100644 --- a/tensorflow/python/lib/core/numpy.h +++ b/tensorflow/python/lib/core/numpy.h @@ -30,8 +30,8 @@ limitations under the License. #endif // Place `` before to avoid build failure in macOS. -#include #include +#include #include "numpy/arrayobject.h" #include "numpy/ufuncobject.h" diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index 30c1a9c75986f242c6cf5a8aa2ed1b64938d2bda..57139986af7d2adc3670529d1bb22233f167ced0 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -55,37 +55,35 @@ struct PyCall { string token; // The device on which Tensors are stored; only used for EagerPyFunc. - Device* device; - - // True if and only if the op has been placed on a GPU. - bool gpu; + Device* device = nullptr; // True if the call is associated with an EagerPyFunc. - bool eager; + bool eager = false; // Inputs and outputs of this function invocation. std::vector ins; std::vector out; }; +bool IsCPUDevice(const Device* d) { + return d == nullptr || d->tensorflow_gpu_device_info() == nullptr; +} + // Givens the 'call', prepares the token and inputs as a python tuple // that is appropriate for calling the trampoline. Status MakeArgTuple(const PyCall* call, PyObject** tuple) { int64 n = call->ins.size(); PyObject* lst = PyList_New(n); CHECK(lst); + // TFE_TensorHandle assumes that CPU is identified by nullptr. + Device* device = IsCPUDevice(call->device) ? nullptr : call->device; for (int64 i = 0; i < n; ++i) { PyObject* arg = nullptr; const Tensor& t = call->ins[i]; if (call->eager) { - if (call->gpu) { - arg = EagerTensorFromHandle( - new TFE_TensorHandle(t, call->device, call->device)); - } else { - // TFE_TensorHandle assumes that CPU is identified by `nullptr`. - arg = EagerTensorFromHandle(new TFE_TensorHandle(t, nullptr, nullptr)); - } + arg = EagerTensorFromHandle(new TFE_TensorHandle(t, device, device)); if (arg == nullptr) { + Py_DECREF(lst); return errors::Internal("Unable to procure EagerTensor from Tensor."); } } else { @@ -97,8 +95,9 @@ Status MakeArgTuple(const PyCall* call, PyObject** tuple) { } PyList_SetItem(lst, i, arg); } - *tuple = Py_BuildValue("(sON)", call->token.c_str(), - call->gpu ? Py_True : Py_False, lst); + const char* device_name = + device == nullptr ? nullptr : device->attributes().name().c_str(); + *tuple = Py_BuildValue("(ssN)", call->token.c_str(), device_name, lst); CHECK(*tuple); return Status::OK(); } @@ -167,9 +166,40 @@ bool IsSingleNone(PyObject* obj) { } // Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`. +// Validates that `output_tensor` is backed by memory in `expected_device` +// (which is assumed to be a local device, one on which the kernel was +// executed.) +// +// It may be nice to copy the tensor to the right device instead of failing if +// it isn't already there. This is left as a future exercise. The required +// device-copying logic is implemented in Python at the moment. tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor, + const Device* expected_device, const Tensor** output_tensor) { - return EagerTensor_Handle(eager_tensor)->handle->Tensor(output_tensor); + auto handle = EagerTensor_Handle(eager_tensor)->handle; + Device* actual_device = nullptr; + TF_RETURN_IF_ERROR(handle->Device(&actual_device)); + TF_RETURN_IF_ERROR(handle->Tensor(output_tensor)); + // actual_device may be nullptr, which implies local CPU. + if (expected_device == actual_device) return Status::OK(); + const string& expected_device_name = expected_device->attributes().name(); + if (actual_device == nullptr) { + if (!IsCPUDevice(expected_device)) { + return errors::Internal( + "expected the py_func to return a Tensor backed by memory in ", + expected_device_name, + ", but is actually backed by local host memory. This is a bug."); + } + return Status::OK(); + } + const string& actual_device_name = actual_device->attributes().name(); + if (actual_device_name != expected_device_name) { + return errors::Internal( + "expected the py_func to return a Tensor backed by memory in ", + expected_device_name, ", but is actually in ", actual_device_name, + ". This is a bug."); + } + return Status::OK(); } // Calls the registered py function through the trampoline. @@ -224,7 +254,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { const PyObject* item = PyList_GetItem(result, i); if (EagerTensor_CheckExact(item)) { const Tensor* tensor = nullptr; - s = ExtractTensorFromEagerTensor(item, &tensor); + s = ExtractTensorFromEagerTensor(item, call->device, &tensor); if (s.ok()) t = *tensor; } else { s = errors::FailedPrecondition( @@ -245,7 +275,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { DCHECK(call->eager); if (result != Py_None) { const Tensor* t = nullptr; - s = ExtractTensorFromEagerTensor(result, &t); + s = ExtractTensorFromEagerTensor(result, call->device, &t); if (s.ok()) call->out.push_back(*t); } } else if (PyArray_Check(result)) { @@ -449,13 +479,11 @@ class PyFuncOp : public OpKernel { explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_)); eager_ = type_string() == "EagerPyFunc"; - gpu_ = ctx->device_type().type_string() == DEVICE_GPU; } void Compute(OpKernelContext* ctx) override { PyCall call; call.token = token_; - call.gpu = gpu_; call.eager = eager_; if (call.eager) { // Eager's C API uses `Device`, whereas `OpKernelContext` stores a @@ -464,6 +492,7 @@ class PyFuncOp : public OpKernel { if (call.device == nullptr) { ctx->CtxFailureWithWarning( errors::Internal("Unrecognized device class")); + return; } } @@ -508,9 +537,6 @@ class PyFuncOp : public OpKernel { private: string token_; - // True if and only if this op has been placed on a GPU. - bool gpu_; - // True if and only if this op should execute the python function eagerly, // i.e., if and only if the eager attribute is set. bool eager_; diff --git a/tensorflow/python/lib/core/py_util.cc b/tensorflow/python/lib/core/py_util.cc index 572693b1cfafa04a7716e09464885faa4c92e299..6b6c82015fd2b73e410d64306ecbd613ccf1967c 100644 --- a/tensorflow/python/lib/core/py_util.cc +++ b/tensorflow/python/lib/core/py_util.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/python/lib/core/py_util.h" // Place `` before to avoid build failure in macOS. -#include #include +#include #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/strcat.h" diff --git a/tensorflow/python/lib/io/tf_record_test.py b/tensorflow/python/lib/io/tf_record_test.py new file mode 100644 index 0000000000000000000000000000000000000000..dcc1a25f420b434e6aa7d37cdf65f693e4d8c01a --- /dev/null +++ b/tensorflow/python/lib/io/tf_record_test.py @@ -0,0 +1,322 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf_record.TFRecordWriter and tf_record.tf_record_iterator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os +import zlib + +import six + +from tensorflow.python.framework import errors_impl +from tensorflow.python.lib.io import tf_record +from tensorflow.python.platform import test +from tensorflow.python.util import compat + +prefix_path = "third_party/tensorflow/core/lib" + +# pylint: disable=invalid-name +TFRecordCompressionType = tf_record.TFRecordCompressionType +# pylint: enable=invalid-name + +# Edgar Allan Poe's 'Eldorado' +_TEXT = b"""Gaily bedight, + A gallant knight, + In sunshine and in shadow, + Had journeyed long, + Singing a song, + In search of Eldorado. + + But he grew old + This knight so bold + And o'er his heart a shadow + Fell as he found + No spot of ground + That looked like Eldorado. + + And, as his strength + Failed him at length, + He met a pilgrim shadow + 'Shadow,' said he, + 'Where can it be + This land of Eldorado?' + + 'Over the Mountains + Of the Moon' + Down the Valley of the Shadow, + Ride, boldly ride,' + The shade replied, + 'If you seek for Eldorado!' + """ + + +class TFCompressionTestCase(test.TestCase): + + def setUp(self): + super(TFCompressionTestCase, self).setUp() + self._num_files = 2 + self._num_records = 7 + + def _Record(self, f, r): + return compat.as_bytes("Record %d of file %d" % (r, f)) + + def _CreateFiles(self, options=None, prefix=""): + filenames = [] + for i in range(self._num_files): + name = prefix + "tfrecord.%d.txt" % i + records = [self._Record(i, j) for j in range(self._num_records)] + fn = self._WriteRecordsToFile(records, name, options) + filenames.append(fn) + return filenames + + def _WriteRecordsToFile(self, records, name="tfrecord", options=None): + fn = os.path.join(self.get_temp_dir(), name) + with tf_record.TFRecordWriter(fn, options=options) as writer: + for r in records: + writer.write(r) + return fn + + def _ZlibCompressFile(self, infile, name="tfrecord.z"): + # zlib compress the file and write compressed contents to file. + with open(infile, "rb") as f: + cdata = zlib.compress(f.read()) + + zfn = os.path.join(self.get_temp_dir(), name) + with open(zfn, "wb") as f: + f.write(cdata) + return zfn + + def _GzipCompressFile(self, infile, name="tfrecord.gz"): + # gzip compress the file and write compressed contents to file. + with open(infile, "rb") as f: + cdata = f.read() + + gzfn = os.path.join(self.get_temp_dir(), name) + with gzip.GzipFile(gzfn, "wb") as f: + f.write(cdata) + return gzfn + + def _ZlibDecompressFile(self, infile, name="tfrecord"): + with open(infile, "rb") as f: + cdata = zlib.decompress(f.read()) + fn = os.path.join(self.get_temp_dir(), name) + with open(fn, "wb") as f: + f.write(cdata) + return fn + + def _GzipDecompressFile(self, infile, name="tfrecord"): + with gzip.GzipFile(infile, "rb") as f: + cdata = f.read() + fn = os.path.join(self.get_temp_dir(), name) + with open(fn, "wb") as f: + f.write(cdata) + return fn + + +class TFRecordWriterTest(TFCompressionTestCase): + + def setUp(self): + super(TFRecordWriterTest, self).setUp() + + def _AssertFilesEqual(self, a, b, equal): + for an, bn in zip(a, b): + with open(an, "rb") as af, open(bn, "rb") as bf: + if equal: + self.assertEqual(af.read(), bf.read()) + else: + self.assertNotEqual(af.read(), bf.read()) + + def testWriteReadZLibFiles(self): + # Write uncompressed then compress manually. + options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE) + files = self._CreateFiles(options, prefix="uncompressed") + zlib_files = [ + self._ZlibCompressFile(fn, "tfrecord_%s.z" % i) + for i, fn in enumerate(files) + ] + self._AssertFilesEqual(files, zlib_files, False) + + # Now write compressd and verify same. + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + compressed_files = self._CreateFiles(options, prefix="compressed") + self._AssertFilesEqual(compressed_files, zlib_files, True) + + # Decompress compress and verify same. + uncompressed_files = [ + self._ZlibDecompressFile(fn, "tfrecord_%s.z" % i) + for i, fn in enumerate(compressed_files) + ] + self._AssertFilesEqual(uncompressed_files, files, True) + + def testWriteReadGzipFiles(self): + # Write uncompressed then compress manually. + options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE) + files = self._CreateFiles(options, prefix="uncompressed") + gzip_files = [ + self._GzipCompressFile(fn, "tfrecord_%s.gz" % i) + for i, fn in enumerate(files) + ] + self._AssertFilesEqual(files, gzip_files, False) + + # Now write compressd and verify same. + options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) + compressed_files = self._CreateFiles(options, prefix="compressed") + + # Note: Gzips written by TFRecordWriter add 'tfrecord_0' so + # compressed_files can't be compared with gzip_files + + # Decompress compress and verify same. + uncompressed_files = [ + self._GzipDecompressFile(fn, "tfrecord_%s.gz" % i) + for i, fn in enumerate(compressed_files) + ] + self._AssertFilesEqual(uncompressed_files, files, True) + + +class TFRecordWriterZlibTest(TFCompressionTestCase): + + def testZLibFlushRecord(self): + original = [b"small record"] + fn = self._WriteRecordsToFile(original, "small_record") + with open(fn, "rb") as h: + buff = h.read() + + # creating more blocks and trailing blocks shouldn't break reads + compressor = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS) + + output = b"" + for c in buff: + if isinstance(c, int): + c = six.int2byte(c) + output += compressor.compress(c) + output += compressor.flush(zlib.Z_FULL_FLUSH) + + output += compressor.flush(zlib.Z_FULL_FLUSH) + output += compressor.flush(zlib.Z_FULL_FLUSH) + output += compressor.flush(zlib.Z_FINISH) + + # overwrite the original file with the compressed data + with open(fn, "wb") as h: + h.write(output) + + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + actual = list(tf_record.tf_record_iterator(fn, options=options)) + self.assertEqual(actual, original) + + def testZlibReadWrite(self): + """Verify that files produced are zlib compatible.""" + original = [b"foo", b"bar"] + fn = self._WriteRecordsToFile(original, "zlib_read_write.tfrecord") + zfn = self._ZlibCompressFile(fn, "zlib_read_write.tfrecord.z") + + # read the compressed contents and verify. + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + actual = list(tf_record.tf_record_iterator(zfn, options=options)) + self.assertEqual(actual, original) + + def testZlibReadWriteLarge(self): + """Verify that writing large contents also works.""" + + # Make it large (about 5MB) + original = [_TEXT * 10240] + fn = self._WriteRecordsToFile(original, "zlib_read_write_large.tfrecord") + zfn = self._ZlibCompressFile(fn, "zlib_read_write_large.tfrecord.z") + + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + actual = list(tf_record.tf_record_iterator(zfn, options=options)) + self.assertEqual(actual, original) + + def testGzipReadWrite(self): + """Verify that files produced are gzip compatible.""" + original = [b"foo", b"bar"] + fn = self._WriteRecordsToFile(original, "gzip_read_write.tfrecord") + gzfn = self._GzipCompressFile(fn, "tfrecord.gz") + + options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) + actual = list(tf_record.tf_record_iterator(gzfn, options=options)) + self.assertEqual(actual, original) + + +class TFRecordIteratorTest(TFCompressionTestCase): + + def setUp(self): + super(TFRecordIteratorTest, self).setUp() + self._num_records = 7 + + def testIterator(self): + records = [self._Record(0, i) for i in range(self._num_records)] + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + fn = self._WriteRecordsToFile(records, "compressed_records", options) + + reader = tf_record.tf_record_iterator(fn, options) + for expected in records: + record = next(reader) + self.assertAllEqual(expected, record) + with self.assertRaises(StopIteration): + record = next(reader) + + def testWriteZlibRead(self): + """Verify compression with TFRecordWriter is zlib library compatible.""" + original = [b"foo", b"bar"] + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + fn = self._WriteRecordsToFile(original, "write_zlib_read.tfrecord.z", + options) + + zfn = self._ZlibDecompressFile(fn, "write_zlib_read.tfrecord") + actual = list(tf_record.tf_record_iterator(zfn)) + self.assertEqual(actual, original) + + def testWriteZlibReadLarge(self): + """Verify compression for large records is zlib library compatible.""" + # Make it large (about 5MB) + original = [_TEXT * 10240] + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + fn = self._WriteRecordsToFile(original, "write_zlib_read_large.tfrecord.z", + options) + zfn = self._ZlibDecompressFile(fn, "write_zlib_read_large.tfrecord") + actual = list(tf_record.tf_record_iterator(zfn)) + self.assertEqual(actual, original) + + def testWriteGzipRead(self): + original = [b"foo", b"bar"] + options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) + fn = self._WriteRecordsToFile(original, "write_gzip_read.tfrecord.gz", + options) + + gzfn = self._GzipDecompressFile(fn, "write_gzip_read.tfrecord") + actual = list(tf_record.tf_record_iterator(gzfn)) + self.assertEqual(actual, original) + + def testBadFile(self): + """Verify that tf_record_iterator throws an exception on bad TFRecords.""" + fn = os.path.join(self.get_temp_dir(), "bad_file") + with tf_record.TFRecordWriter(fn) as writer: + writer.write(b"123") + fn_truncated = os.path.join(self.get_temp_dir(), "bad_file_truncated") + with open(fn, "rb") as f: + with open(fn_truncated, "wb") as f2: + # DataLossError requires that we've written the header, so this must + # be at least 12 bytes. + f2.write(f.read(14)) + with self.assertRaises(errors_impl.DataLossError): + for _ in tf_record.tf_record_iterator(fn_truncated): + pass + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 3678bd4c1f6a4500622b6d9e8334cb1ebae46578..fe459a96b98733f8a706b0c3b84000c5a74894ad 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -568,7 +568,6 @@ ops.NotDifferentiable("Size") @ops.RegisterGradient("Tile") def _TileGrad(op, grad): """Sum reduces grad along the tiled dimensions.""" - assert isinstance(grad, ops.Tensor) input_shape = array_ops.shape(op.inputs[0]) # We interleave multiples and input_shape to get split_shape, # reshape grad to split_shape, and reduce along all even @@ -581,6 +580,13 @@ def _TileGrad(op, grad): split_shape = array_ops.reshape( array_ops.transpose(array_ops.stack([op.inputs[1], input_shape])), [-1]) axes = math_ops.range(0, array_ops.size(split_shape), 2) + # Sum reduces grad along the first dimension for IndexedSlices + if isinstance(grad, ops.IndexedSlices): + grad = math_ops.unsorted_segment_sum( + grad.values, + math_ops.mod(grad.indices, input_shape[0]), + input_shape[0]) + split_shape = array_ops.concat([[1], split_shape[1:]], axis=0) input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes) # Fix shape inference if not context.executing_eagerly(): diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index fae63b1132cca527c6bc5d5f9f5c8be2952d8f3c..361667ec49aba9705787c3c7ac096add36afb40b 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -41,6 +41,7 @@ from tensorflow.python.ops import gen_math_ops # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_array_ops import * +from tensorflow.python.ops.gen_array_ops import reverse_v2 as reverse # pylint: disable=unused-import from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export # pylint: enable=wildcard-import @@ -2609,14 +2610,6 @@ def where(condition, x=None, y=None, name=None): raise ValueError("x and y must both be non-None or both be None.") -@tf_export("reverse") -def reverse(tensor, axis, name=None): - return gen_array_ops.reverse_v2(tensor, axis, name) - - -reverse.__doc__ = gen_array_ops.reverse_v2.__doc__ - - # pylint: disable=redefined-builtin @tf_export("reverse_sequence") @deprecation.deprecated_args( diff --git a/tensorflow/python/ops/boosted_trees_ops.py b/tensorflow/python/ops/boosted_trees_ops.py index 2a2bcdd9d69b7a0aed1e7f3d3197cf6d7dd98451..9ebb607c475d444bfc78369b8f5415ac93b0dee2 100644 --- a/tensorflow/python/ops/boosted_trees_ops.py +++ b/tensorflow/python/ops/boosted_trees_ops.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import resources # Re-exporting ops used by other modules. # pylint: disable=unused-import from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_gains_per_feature as calculate_best_gains_per_feature +from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_example_debug_outputs as example_debug_outputs from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_stats_summary as make_stats_summary from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_predict as predict from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_training_predict as training_predict diff --git a/tensorflow/python/ops/collective_ops.py b/tensorflow/python/ops/collective_ops.py index a05fd15eca12a423bf02dfb13044dd1f7630b99c..98668facd5bc56892fa00f258dfebcbe93c063da 100644 --- a/tensorflow/python/ops/collective_ops.py +++ b/tensorflow/python/ops/collective_ops.py @@ -22,7 +22,7 @@ from tensorflow.python.ops import gen_collective_ops def all_reduce(t, group_size, group_key, instance_key, merge_op, final_op, - subdiv_offsets=(0)): + subdiv_offsets=(0,)): """Reduces tensors collectively, across devices. Args: diff --git a/tensorflow/python/ops/collective_ops_test.py b/tensorflow/python/ops/collective_ops_test.py index 8e16cffdf4917ba361a3c313047e39af514273bc..9cc64ef9f631faf2f76c3dbb3e70e1f37bbe4b1a 100644 --- a/tensorflow/python/ops/collective_ops_test.py +++ b/tensorflow/python/ops/collective_ops_test.py @@ -37,11 +37,11 @@ class CollectiveOpTest(test.TestCase): with ops.device('/CPU:0'): in0 = constant_op.constant(t0) colred0 = collective_ops.all_reduce(in0, 2, group_key, instance_key, - 'Add', 'Div', [0]) + 'Add', 'Div') with ops.device('/CPU:1'): in1 = constant_op.constant(t1) colred1 = collective_ops.all_reduce(in1, 2, group_key, instance_key, - 'Add', 'Div', [0]) + 'Add', 'Div') run_options = config_pb2.RunOptions() run_options.experimental.collective_graph_key = 1 results = sess.run([colred0, colred1], options=run_options) diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..76173e0f309b80402a15acdab5d2af49f35de741 --- /dev/null +++ b/tensorflow/python/ops/cond_v2.py @@ -0,0 +1,32 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""cond_v2 wrapper module. + +This imports the cond_v2 method and all necessary dependencies (this is to avoid +circular dependencies in the cond_v2 implementation). See cond_v2_impl for more +information. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import +from tensorflow.python.framework import function +from tensorflow.python.framework import function_def_to_graph +from tensorflow.python.ops import gradients_impl + +from tensorflow.python.ops.cond_v2_impl import cond_v2 +# pylint: enable=unused-import diff --git a/tensorflow/contrib/control_flow/python/cond_v2.py b/tensorflow/python/ops/cond_v2_impl.py similarity index 89% rename from tensorflow/contrib/control_flow/python/cond_v2.py rename to tensorflow/python/ops/cond_v2_impl.py index 90371cd8d70db11dc77af02a2b1fd2a90f3dcf44..d310f83dca97889157eb078b11a3ca51caae2fc2 100644 --- a/tensorflow/contrib/control_flow/python/cond_v2.py +++ b/tensorflow/python/ops/cond_v2_impl.py @@ -17,23 +17,32 @@ This is a version of cond that emits a single If op, as well as the gradient function for If ops produced by cond_v2. This will eventually replace the current tf.cond implementation once it reaches feature and performance parity. + +NOTE: most users of cond_v2 should import cond_v2, not this module! This module +does not contain all the necessary imports to prevent circular dependencies, +while cond_v2 does. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.core.framework import attr_value_pb2 from tensorflow.python import pywrap_tensorflow as c_api from tensorflow.python.framework import c_api_util -from tensorflow.python.framework import function -from tensorflow.python.framework import function_def_to_graph from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import gen_functional_ops -from tensorflow.python.ops import gradients_impl from tensorflow.python.util import compat +# The following modules cannot be imported directly because they cause circular +# dependencies. These are set in each corresponding module. +_function = None +_function_def_to_graph = None +_gradients_impl = None + # NOTE(skyewm): TensorFlow uses protected class methods and fields to signify # that they aren't part of the official public API. These protected members # often need to be used by implementation code however. Rather than litter the @@ -58,14 +67,14 @@ def cond_v2(pred, true_fn, false_fn, name="cond"): func_name_prefix = scope.replace("/", "_") - true_graph = function.func_graph_from_py_func( + true_graph = _function.func_graph_from_py_func( true_fn, [], [], name="%strue" % func_name_prefix, device=caller_device, colocation_stack=caller_colocation_stack, collections_ref=caller_collection_ref, container=caller_container) - false_graph = function.func_graph_from_py_func( + false_graph = _function.func_graph_from_py_func( false_fn, [], [], name="%sfalse" % func_name_prefix, device=caller_device, @@ -103,6 +112,22 @@ def cond_v2(pred, true_fn, false_fn, name="cond"): _create_new_tf_function(false_graph), name=scope) + # Set the flag to enable lowering on the `if` op if necessary + # Lowering allows cond_v2 to avoid some of the limitations of Functions, + # allowing users to specify devices & colocation inside of cond_v2 branches, + # and enabling non-strict evaluation & partial pruning of cond_v2 branches. + # This brings cond_v2 closer to feature parity with tf.cond. + # + # However, we do not lower `If` in the XLA context because it is easier for + # XLA to apply its own optimizations when dealing with un-lowered `If` + # operators than with lowered switch/merge control flow. + # + # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output + if_op = tensors[0].op + if not control_flow_util.IsInXLAContext(if_op): + if_op._set_attr("_lower_using_switch_merge", + attr_value_pb2.AttrValue(b=True)) + return tensors[:num_cond_outputs] @@ -169,11 +194,13 @@ def _get_func_graphs(if_op): A 2-tuple of the `_FuncGraph`s of the then_branch and else_branch. """ def _get_func_graph_for_branch(branch_name): + """Generates and returns a _FuncGraph for the given branch.""" extra_inputs = if_op.inputs[1:] # First input is pred. input_shapes = [t.shape for t in extra_inputs] func_name = if_op.get_attr(branch_name).name fdef = if_op.graph._get_function(func_name).definition - func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes) + func_graph = _function_def_to_graph.function_def_to_graph( + fdef, input_shapes) func_graph.extra_inputs = extra_inputs func_graph.extra_args = func_graph.inputs func_graph._captured = dict(zip(extra_inputs, func_graph.inputs)) @@ -205,7 +232,7 @@ def _grad_fn(func_graph, grads): ys = [] grad_ys = [] for y, grad_y in zip(func_graph.outputs, grads): - if not gradients_impl._IsTrainable(y): + if not _gradients_impl._IsTrainable(y): continue ys.append(y) grad_ys.append(grad_y) @@ -214,7 +241,7 @@ def _grad_fn(func_graph, grads): # func_graph in the current graph, which requires capturing tensors from # func_graph. The captured func_graph tensors are resolved to external tensors # in _get_grad_inputs. - result = gradients_impl._GradientsHelper( + result = _gradients_impl._GradientsHelper( ys, func_graph.inputs, grad_ys=grad_ys, src_graph=func_graph) @@ -230,8 +257,8 @@ def _grad_fn(func_graph, grads): def _create_grad_func(func_graph, grads, name): """Returns the _FuncGraph representation of _grad_fn.""" - return function.func_graph_from_py_func(lambda: _grad_fn(func_graph, grads), - [], [], name) + return _function.func_graph_from_py_func(lambda: _grad_fn(func_graph, grads), + [], [], name) def _get_grad_inputs(if_op, cond_graph, grad_graph): @@ -297,8 +324,8 @@ def _create_new_tf_function(func_graph): # TODO(b/109833212): this sucks, we're serializing the TF_Function*, # deserializing it into a Python FunctionDef, then reserializing it to create # a new TF_Function that we add to the graph. - fdef = function.function_def_from_tf_function(c_func) - defined_func = function._from_definition(fdef) + fdef = _function.function_def_from_tf_function(c_func) + defined_func = _function._from_definition(fdef) defined_func.add_to_graph(ops.get_default_graph()) return func_graph.name diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 2e5a801f8e96aa1266695a1440d98e6bff53607c..fc37805c79916ca9108481f7b6e69c381c2ff9d2 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -24,6 +24,7 @@ from __future__ import print_function import abc import collections import functools +import os import six @@ -38,6 +39,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import cond_v2_impl from tensorflow.python.ops import control_flow_util as util from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_control_flow_ops @@ -57,6 +59,10 @@ from tensorflow.python.util import nest from tensorflow.python.util import tf_should_use from tensorflow.python.util.tf_export import tf_export + +_ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0" + + # We override the 'tuple' for a control flow op, so we keep python's # existing 'tuple' for later use in this module. _basetuple = tuple @@ -596,7 +602,6 @@ def _EnforceShapeInvariant(merge_var, next_var): enter = merge_var.op.inputs[0].op assert util.IsLoopEnter(enter) input_t = enter.inputs[0] - assert input_t.shape == m_shape raise ValueError( "Input tensor '%s' enters the loop with shape %s, but has shape %s " "after one iteration. To allow the shape to vary across iterations, " @@ -1994,6 +1999,9 @@ def cond(pred, ``` """ + if _ENABLE_COND_V2: + return cond_v2_impl.cond_v2(pred, true_fn, false_fn, name) + # We needed to make true_fn/false_fn keyword arguments for # backwards-compatibility. This check exists so that we can convert back to # having them be positional arguments. @@ -2935,9 +2943,10 @@ class WhileContext(ControlFlowContext): loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars) try: self.Enter() - # _BuildLoop calls _update_input in several places. _lock ensures a - # Session.run call cannot occur between creating and mutating new ops. - with ops.get_default_graph()._lock: # pylint: disable=protected-access + # _BuildLoop calls _update_input in several places. _mutation_lock() + # ensures a Session.run call cannot occur between creating and mutating + # new ops. + with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access original_body_result, exit_vars = self._BuildLoop( pred, body, original_loop_vars, loop_vars, shape_invariants) finally: @@ -3126,6 +3135,7 @@ def while_loop(cond, happen is that the thread updating `x` can never get ahead of the counter thread because the thread incrementing `x` depends on the value of the counter. + ```python import tensorflow as tf @@ -3340,12 +3350,6 @@ def group(*inputs, **kwargs): if not hasattr(inp, "device"): raise TypeError("Expected tf.group() expected Tensor arguments not " "'%s' with type '%s'" % (inp, type(inp))) - if not hasattr(inp, "device"): - if isinstance(inp, list): - raise TypeError("To call tf.group() with a list, use " - "tf.group(*[...]) not tf.group([...]).") - raise TypeError("Expected tf.group() expected Tensor arguments not " - "'%s' with type '%s'" % (inp, type(inp))) dev = inp.device if dev in ops_on_device: ops_on_device[dev].append(inp) diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py index 59bb925df0f25b3bf88112bc3eb1b13b21ace414..43fe045bcb10d2fc383381f92f2bc44c5362ac7d 100644 --- a/tensorflow/python/ops/control_flow_ops_test.py +++ b/tensorflow/python/ops/control_flow_ops_test.py @@ -939,7 +939,7 @@ class CaseTest(test_util.TensorFlowTestCase): class WhileLoopTestCase(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testWhileLoopWithSingleVariable(self): i = constant_op.constant(0) c = lambda i: math_ops.less(i, 10) @@ -948,7 +948,7 @@ class WhileLoopTestCase(test_util.TensorFlowTestCase): self.assertEqual(self.evaluate(r), 10) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerWhileLoopWithSingleVariable_bodyReturnsTuple(self): i = constant_op.constant(0) c = lambda i: math_ops.less(i, 10) diff --git a/tensorflow/python/ops/conv2d_benchmark.py b/tensorflow/python/ops/conv2d_benchmark.py index 907df85cd954d2a897ba9a0c4b21be8586859380..aacdaa7ad019d8aae2d0b533cde8412ab0f0fa22 100644 --- a/tensorflow/python/ops/conv2d_benchmark.py +++ b/tensorflow/python/ops/conv2d_benchmark.py @@ -21,6 +21,8 @@ from __future__ import print_function import itertools import time +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session as session_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -28,22 +30,32 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables +from tensorflow.python.platform import flags from tensorflow.python.platform import test +FLAGS = flags.FLAGS -def build_graph(device, input_shape, filter_shape, strides, padding, dtype, - num_iters, warmup_iters): +flags.DEFINE_boolean( + "enable_layout_optimizer", False, + "If true, enables layout optimizer to update input data format for faster " + "execution of convolution ops.") + + +def build_graph(device, dtype, data_format, input_shape, filter_shape, strides, + padding, num_iters, warmup_iters): """builds a graph containing a sequence of conv2d operations. Args: device: String, the device to run on. + dtype: Data type for the convolution. + data_format: A string from: "NHWC" or "NCHW". Data format for input and + output data. input_shape: Shape of the input tensor. filter_shape: Shape of the filter tensor. strides: A list of ints. 1-D of length 4. The stride of sliding window for each dimension of input. padding: A string from: "SAME", "VALID". The type of padding algorithm to use. - dtype: Data type for the convolution. num_iters: number of iterations to run conv2d. warmup_iters: number of iterations for warmup runs. @@ -57,22 +69,23 @@ def build_graph(device, input_shape, filter_shape, strides, padding, dtype, random_ops.truncated_normal(filter_shape, dtype=dtype)) outputs = [] - conv2d_op = nn_ops.conv2d(inp, filt, strides, padding, data_format="NHWC") + conv2d_op = nn_ops.conv2d( + inp, filt, strides, padding, data_format=data_format) outputs.append(conv2d_op) for _ in range(1, num_iters): with ops.control_dependencies([conv2d_op]): conv2d_op = nn_ops.conv2d( - inp, filt, strides, padding, data_format="NHWC") + inp, filt, strides, padding, data_format=data_format) outputs.append(conv2d_op) warmup_groups = [] warmup_conv2d_op = nn_ops.conv2d( - inp, filt, strides, padding, data_format="NHWC") + inp, filt, strides, padding, data_format=data_format) warmup_groups.append(warmup_conv2d_op) for _ in range(1, warmup_iters): with ops.control_dependencies([warmup_conv2d_op]): warmup_conv2d_op = nn_ops.conv2d( - inp, filt, strides, padding, data_format="NHWC") + inp, filt, strides, padding, data_format=data_format) warmup_groups.append(warmup_conv2d_op) return control_flow_ops.group(*warmup_groups), control_flow_ops.group( *outputs) @@ -81,12 +94,15 @@ def build_graph(device, input_shape, filter_shape, strides, padding, dtype, class Conv2DBenchmark(test.Benchmark): """Benchmark conv2d!""" - def _run_graph(self, device, input_shape, filter_shape, strides, padding, - dtype, num_iters, warmup_iters): + def _run_graph(self, device, dtype, data_format, input_shape, filter_shape, + strides, padding, num_iters, warmup_iters): """runs the graph and print its execution time. Args: device: String, the device to run on. + dtype: Data type for the convolution. + data_format: A string from: "NHWC" or "NCHW". Data format for input and + output data. input_shape: Shape of the input tensor. filter_shape: Shape of the filter tensor. strides: A list of ints. 1-D of length 4. The stride of sliding @@ -94,7 +110,6 @@ class Conv2DBenchmark(test.Benchmark): padding: A string from: "SAME", "VALID". The type of padding algorithm to use. num_iters: Number of iterations to run the benchmark. - dtype: Data type for the convolution. num_iters: number of iterations to run conv2d. warmup_iters: number of iterations for warmup runs. @@ -103,10 +118,27 @@ class Conv2DBenchmark(test.Benchmark): """ graph = ops.Graph() with graph.as_default(): - warmup_outputs, outputs = build_graph(device, input_shape, filter_shape, - strides, padding, dtype, num_iters, - warmup_iters) - with session_lib.Session(graph=graph) as session: + warmup_outputs, outputs = build_graph(device, dtype, data_format, + input_shape, filter_shape, strides, + padding, num_iters, warmup_iters) + + config = config_pb2.ConfigProto() + config.graph_options.optimizer_options.opt_level = -1 + rewrite_options = config.graph_options.rewrite_options + + # Disable layout optimizer to not change input data_format. + rewrite_options.layout_optimizer = ( + rewriter_config_pb2.RewriterConfig.ON if FLAGS.enable_layout_optimizer + else rewriter_config_pb2.RewriterConfig.OFF) + # Convolution ops are effectively noop in the test graph as we are not + # fetching the convolution outputs. Disable dependency optimizer to not + # remove the conv ops. + rewrite_options.dependency_optimization = ( + rewriter_config_pb2.RewriterConfig.OFF) + + with session_lib.Session(graph=graph, config=config) as session: + # TODO(hinsu): Use run_op_benchmark method from test.Benchmark to run + # benchmark along with warmup. variables.global_variables_initializer().run() # warmup runs session.run(warmup_outputs) @@ -114,20 +146,21 @@ class Conv2DBenchmark(test.Benchmark): start_time = time.time() session.run(outputs) duration = (time.time() - start_time) / num_iters - print("%s %s inputshape:%s filtershape:%s strides:%s padding:%s " + print("%s %s %s inputshape:%s filtershape:%s strides:%s padding:%s " "%d iters: %.8f sec" % - (device, str(dtype), str(input_shape).replace(" ", ""), - str(filter_shape).replace(" ", ""), + (device, str(dtype), data_format, str(input_shape).replace( + " ", ""), str(filter_shape).replace(" ", ""), str(strides).replace(" ", ""), padding, num_iters, duration)) name_template = ( - "conv2d_{device}_{datatype}_input_shape_{inputshape}_" + "conv2d_{device}_{datatype}_{data_format}_input_shape_{inputshape}_" "filter_shape_{filtershape}_strides_{strides}_padding_{padding}") self.report_benchmark( name=name_template.format( device=device, datatype=str(dtype), + data_format=str(data_format), inputshape=str(input_shape).replace(" ", ""), filtershape=str(filter_shape).replace(" ", ""), strides=str(strides).replace(" ", ""), @@ -140,24 +173,37 @@ class Conv2DBenchmark(test.Benchmark): def benchmark_conv2d(self): print("conv2d benchmark:") - h = 500 - w = 500 - fh = 3 - fw = 3 - input_shapes = [] - filter_shapes = [] data_types = [dtypes.float32, dtypes.float16] - for b, c in itertools.product([4, 16, 32], [i for i in range(3, 16)]): - input_shapes += [[b, h, w, c]] - filter_shapes += [[fh, fw, c, b]] - strides = [[1, 2, 2, 1]] + data_formats = ["NHWC", "NCHW"] + in_channels = list(range(3, 16)) + out_channels = [4, 16, 32] + hw_strides = [[2, 2]] paddings = ["VALID", "SAME"] - for ishape, fshape in zip(input_shapes, filter_shapes): - for dtype in data_types: - for stride in strides: - for padding in paddings: - self._run_graph("gpu", ishape, fshape, stride, padding, dtype, 80, - 2) + + args_lists = [ + data_types, data_formats, in_channels, out_channels, hw_strides, + paddings + ] + for args in itertools.product(*args_lists): + dtype, data_format, in_channel, out_channel, hw_stride, padding = args + + # Keep batch size same as out channels just to reduce the number of + # different configurations to benchmark. + batch_size = out_channel + h, w, fh, fw = 500, 500, 3, 3 + if data_format == "NHWC": + ishape = [batch_size, h, w, in_channel] + stride = [1] + hw_stride + [1] + elif data_format == "NCHW": + ishape = [batch_size, in_channel, h, w] + stride = [1, 1] + hw_stride + else: + raise ValueError("Unknown data_format: " + str(data_format)) + fshape = [fh, fw, in_channel, out_channel] + num_iters = 80 + warmup_iters = 2 + self._run_graph("gpu", dtype, data_format, ishape, fshape, stride, + padding, num_iters, warmup_iters) if __name__ == "__main__": diff --git a/tensorflow/python/ops/distributions/beta.py b/tensorflow/python/ops/distributions/beta.py index f28f76b6c42a861c51c1fc06f99fa73b71b625a9..99d30b0bd112b62c625a94b43da589f9717d0774 100644 --- a/tensorflow/python/ops/distributions/beta.py +++ b/tensorflow/python/ops/distributions/beta.py @@ -84,13 +84,24 @@ class Beta(distribution.Distribution): Distribution parameters are automatically broadcast in all functions; see examples for details. + Warning: The samples can be zero due to finite precision. + This happens more often when some of the concentrations are very small. + Make sure to round the samples to `np.finfo(dtype).tiny` before computing the + density. + + Samples of this distribution are reparameterized (pathwise differentiable). + The derivatives are computed using the approach described in the paper + + [Michael Figurnov, Shakir Mohamed, Andriy Mnih. + Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498) + #### Examples ```python # Create a batch of three Beta distributions. alpha = [1, 2, 3] beta = [1, 2, 3] - dist = Beta(alpha, beta) + dist = tf.distributions.Beta(alpha, beta) dist.sample([4, 5]) # Shape [4, 5, 3] @@ -106,7 +117,7 @@ class Beta(distribution.Distribution): # Create batch_shape=[2, 3] via parameter broadcast: alpha = [[1.], [2]] # Shape [2, 1] beta = [3., 4, 5] # Shape [3] - dist = Beta(alpha, beta) + dist = tf.distributions.Beta(alpha, beta) # alpha broadcast as: [[1., 1, 1,], # [2, 2, 2]] @@ -122,6 +133,18 @@ class Beta(distribution.Distribution): dist.prob(x) # Shape [2, 3] ``` + Compute the gradients of samples w.r.t. the parameters: + + ```python + alpha = tf.constant(1.0) + beta = tf.constant(2.0) + dist = tf.distributions.Beta(alpha, beta) + samples = dist.sample(5) # Shape [5] + loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function + # Unbiased stochastic gradients of the loss function + grads = tf.gradients(loss, [alpha, beta]) + ``` + """ def __init__(self, @@ -165,7 +188,7 @@ class Beta(distribution.Distribution): dtype=self._total_concentration.dtype, validate_args=validate_args, allow_nan_stats=allow_nan_stats, - reparameterization_type=distribution.NOT_REPARAMETERIZED, + reparameterization_type=distribution.FULLY_REPARAMETERIZED, parameters=parameters, graph_parents=[self._concentration1, self._concentration0, diff --git a/tensorflow/python/ops/distributions/dirichlet.py b/tensorflow/python/ops/distributions/dirichlet.py index 72567e62f78665947c001282c9c4f4929e9ea0ef..9104a1d071af3d7b7d40838148f2e49301fa39ba 100644 --- a/tensorflow/python/ops/distributions/dirichlet.py +++ b/tensorflow/python/ops/distributions/dirichlet.py @@ -90,13 +90,24 @@ class Dirichlet(distribution.Distribution): Distribution parameters are automatically broadcast in all functions; see examples for details. + Warning: Some components of the samples can be zero due to finite precision. + This happens more often when some of the concentrations are very small. + Make sure to round the samples to `np.finfo(dtype).tiny` before computing the + density. + + Samples of this distribution are reparameterized (pathwise differentiable). + The derivatives are computed using the approach described in the paper + + [Michael Figurnov, Shakir Mohamed, Andriy Mnih. + Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498) + #### Examples ```python # Create a single trivariate Dirichlet, with the 3rd class being three times # more frequent than the first. I.e., batch_shape=[], event_shape=[3]. alpha = [1., 2, 3] - dist = Dirichlet(alpha) + dist = tf.distributions.Dirichlet(alpha) dist.sample([4, 5]) # shape: [4, 5, 3] @@ -118,7 +129,7 @@ class Dirichlet(distribution.Distribution): # Create batch_shape=[2], event_shape=[3]: alpha = [[1., 2, 3], [4, 5, 6]] # shape: [2, 3] - dist = Dirichlet(alpha) + dist = tf.distributions.Dirichlet(alpha) dist.sample([4, 5]) # shape: [4, 5, 2, 3] @@ -129,6 +140,17 @@ class Dirichlet(distribution.Distribution): dist.prob(x) # shape: [2] ``` + Compute the gradients of samples w.r.t. the parameters: + + ```python + alpha = tf.constant([1.0, 2.0, 3.0]) + dist = tf.distributions.Dirichlet(alpha) + samples = dist.sample(5) # Shape [5, 3] + loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function + # Unbiased stochastic gradients of the loss function + grads = tf.gradients(loss, alpha) + ``` + """ def __init__(self, @@ -165,7 +187,7 @@ class Dirichlet(distribution.Distribution): dtype=self._concentration.dtype, validate_args=validate_args, allow_nan_stats=allow_nan_stats, - reparameterization_type=distribution.NOT_REPARAMETERIZED, + reparameterization_type=distribution.FULLY_REPARAMETERIZED, parameters=parameters, graph_parents=[self._concentration, self._total_concentration], @@ -290,10 +312,8 @@ class Dirichlet(distribution.Distribution): if not self.validate_args: return x return control_flow_ops.with_dependencies([ - check_ops.assert_positive( - x, - message="samples must be positive"), - distribution_util.assert_close( + check_ops.assert_positive(x, message="samples must be positive"), + check_ops.assert_near( array_ops.ones([], dtype=self.dtype), math_ops.reduce_sum(x, -1), message="sample last-dimension must sum to `1`"), diff --git a/tensorflow/python/ops/distributions/exponential.py b/tensorflow/python/ops/distributions/exponential.py index 24bc3f3d3eb06a01d5173cb6c7fb0f09172a0587..4325a14449dd9a13dabb65a240ede452544c761a 100644 --- a/tensorflow/python/ops/distributions/exponential.py +++ b/tensorflow/python/ops/distributions/exponential.py @@ -103,9 +103,6 @@ class Exponential(gamma.Gamma): allow_nan_stats=allow_nan_stats, validate_args=validate_args, name=name) - # While the Gamma distribution is not reparameterizable, the exponential - # distribution is. - self._reparameterization_type = True self._parameters = parameters self._graph_parents += [self._rate] diff --git a/tensorflow/python/ops/distributions/gamma.py b/tensorflow/python/ops/distributions/gamma.py index 163a27f7585518c321dd1ea59b71029e2ae6a1e7..b631f0247c59e518fbd4925065d33345d4ea8e47 100644 --- a/tensorflow/python/ops/distributions/gamma.py +++ b/tensorflow/python/ops/distributions/gamma.py @@ -55,7 +55,7 @@ class Gamma(distribution.Distribution): ```none pdf(x; alpha, beta, x > 0) = x**(alpha - 1) exp(-x beta) / Z - Z = Gamma(alpha) beta**alpha + Z = Gamma(alpha) beta**(-alpha) ``` where: @@ -85,14 +85,35 @@ class Gamma(distribution.Distribution): Distribution parameters are automatically broadcast in all functions; see examples for details. - WARNING: This distribution may draw 0-valued samples for small `concentration` - values. See note in `tf.random_gamma` docstring. + Warning: The samples of this distribution are always non-negative. However, + the samples that are smaller than `np.finfo(dtype).tiny` are rounded + to this value, so it appears more often than it should. + This should only be noticeable when the `concentration` is very small, or the + `rate` is very large. See note in `tf.random_gamma` docstring. + + Samples of this distribution are reparameterized (pathwise differentiable). + The derivatives are computed using the approach described in the paper + + [Michael Figurnov, Shakir Mohamed, Andriy Mnih. + Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498) #### Examples ```python - dist = Gamma(concentration=3.0, rate=2.0) - dist2 = Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0]) + dist = tf.distributions.Gamma(concentration=3.0, rate=2.0) + dist2 = tf.distributions.Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0]) + ``` + + Compute the gradients of samples w.r.t. the parameters: + + ```python + concentration = tf.constant(3.0) + rate = tf.constant(2.0) + dist = tf.distributions.Gamma(concentration, rate) + samples = dist.sample(5) # Shape [5] + loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function + # Unbiased stochastic gradients of the loss function + grads = tf.gradients(loss, [concentration, rate]) ``` """ @@ -141,7 +162,7 @@ class Gamma(distribution.Distribution): dtype=self._concentration.dtype, validate_args=validate_args, allow_nan_stats=allow_nan_stats, - reparameterization_type=distribution.NOT_REPARAMETERIZED, + reparameterization_type=distribution.FULLY_REPARAMETERIZED, parameters=parameters, graph_parents=[self._concentration, self._rate], diff --git a/tensorflow/python/ops/distributions/student_t.py b/tensorflow/python/ops/distributions/student_t.py index 20a2d16181442bede797ded5e4d3ebbd3d55ca2b..e0cf6f86f10eec76bf94cd74f64202c452425886 100644 --- a/tensorflow/python/ops/distributions/student_t.py +++ b/tensorflow/python/ops/distributions/student_t.py @@ -80,6 +80,12 @@ class StudentT(distribution.Distribution): variance. However it is not actually the std. deviation; the Student's t-distribution std. dev. is `scale sqrt(df / (df - 2))` when `df > 2`. + Samples of this distribution are reparameterized (pathwise differentiable). + The derivatives are computed using the approach described in the paper + + [Michael Figurnov, Shakir Mohamed, Andriy Mnih. + Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498) + #### Examples Examples of initialization of one or a batch of distributions. @@ -118,6 +124,19 @@ class StudentT(distribution.Distribution): dist.prob(3.0) ``` + Compute the gradients of samples w.r.t. the parameters: + + ```python + df = tf.constant(2.0) + loc = tf.constant(2.0) + scale = tf.constant(11.0) + dist = tf.distributions.StudentT(df=df, loc=loc, scale=scale) + samples = dist.sample(5) # Shape [5] + loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function + # Unbiased stochastic gradients of the loss function + grads = tf.gradients(loss, [df, loc, scale]) + ``` + """ # pylint: enable=line-too-long @@ -168,7 +187,7 @@ class StudentT(distribution.Distribution): (self._df, self._loc, self._scale)) super(StudentT, self).__init__( dtype=self._scale.dtype, - reparameterization_type=distribution.NOT_REPARAMETERIZED, + reparameterization_type=distribution.FULLY_REPARAMETERIZED, validate_args=validate_args, allow_nan_stats=allow_nan_stats, parameters=parameters, diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py index 401676bf842b4dd76fc64b5f4599804a0f3a46f8..3e480a79f52b178789a2d34e98c6af31048c07b1 100644 --- a/tensorflow/python/ops/distributions/util.py +++ b/tensorflow/python/ops/distributions/util.py @@ -36,43 +36,6 @@ from tensorflow.python.ops import nn from tensorflow.python.util import tf_inspect -def assert_close( - x, y, data=None, summarize=None, message=None, name="assert_close"): - """Assert that x and y are within machine epsilon of each other. - - Args: - x: Floating-point `Tensor` - y: Floating-point `Tensor` - data: The tensors to print out if the condition is `False`. Defaults to - error message and first few entries of `x` and `y`. - summarize: Print this many entries of each tensor. - message: A string to prefix to the default message. - name: A name for this operation (optional). - - Returns: - Op raising `InvalidArgumentError` if |x - y| > machine epsilon. - """ - message = message or "" - x = ops.convert_to_tensor(x, name="x") - y = ops.convert_to_tensor(y, name="y") - - if data is None: - data = [ - message, - "Condition x ~= y did not hold element-wise: x = ", x, "y = ", y - ] - - if x.dtype.is_integer: - return check_ops.assert_equal( - x, y, data=data, summarize=summarize, message=message, name=name) - - with ops.name_scope(name, "assert_close", [x, y, data]): - tol = np.finfo(x.dtype.as_numpy_dtype).eps - condition = math_ops.reduce_all(math_ops.less_equal(math_ops.abs(x-y), tol)) - return control_flow_ops.Assert( - condition, data, summarize=summarize) - - def assert_integer_form( x, data=None, summarize=None, message=None, int_dtype=None, name="assert_integer_form"): @@ -241,8 +204,12 @@ def get_logits_and_probs(logits=None, dependencies = [check_ops.assert_non_negative(probs)] if multidimensional: probs = embed_check_categorical_event_shape(probs) - dependencies += [assert_close(math_ops.reduce_sum(probs, -1), one, - message="probs does not sum to 1.")] + dependencies += [ + check_ops.assert_near( + math_ops.reduce_sum(probs, -1), + one, + message="probs does not sum to 1.") + ] else: dependencies += [check_ops.assert_less_equal( probs, one, message="probs has components greater than 1.")] diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index 7385cb758514e160efec61d731e734d1af126742..889a00190ed99ecf3da8ba753724409627ae42c6 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -20,6 +20,7 @@ from __future__ import print_function import collections import contextlib +import sys import warnings import numpy as np @@ -36,6 +37,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_grad # pylint: disable=unused-import from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops # pylint: disable=unused-import +from tensorflow.python.ops import cond_v2_impl from tensorflow.python.ops import control_flow_grad # pylint: disable=unused-import from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util @@ -47,12 +49,16 @@ from tensorflow.python.ops import logging_ops # pylint: disable=unused-import from tensorflow.python.ops import manip_grad # pylint: disable=unused-import from tensorflow.python.ops import math_grad # pylint: disable=unused-import from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_grad # pylint: disable=unused-import from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import spectral_grad # pylint: disable=unused-import from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export +# This is to avoid a circular dependency with cond_v2_impl. +cond_v2_impl._gradients_impl = sys.modules[__name__] # pylint: disable=protected-access + # Warn the user if we convert a sparse representation to dense with at # least this number of elements. _LARGE_SPARSE_NUM_ELEMENTS = 100000000 @@ -125,32 +131,6 @@ def _MarkReachedOps(from_ops, reached_ops): queue.extend(output.consumers()) -def _GatherInputs(to_ops, reached_ops): - """List all inputs of to_ops that are in reached_ops. - - Args: - to_ops: list of Operations. - reached_ops: set of Operations. - - Returns: - The list of all inputs of to_ops that are in reached_ops. - That list includes all elements of to_ops. - """ - inputs = [] - queue = collections.deque() - queue.extend(to_ops) - while queue: - op = queue.popleft() - # We are interested in this op. - if op in reached_ops: - inputs.append(op) - # Clear the boolean so we won't add the inputs again. - reached_ops.remove(op) - for inp in op.inputs: - queue.append(inp.op) - return inputs - - def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops): """Initialize the pending count for ops between two lists of Operations. @@ -374,7 +354,11 @@ def _SymGrad(op, out_grads): f.name = op.type for k in op.node_def.attr: f.attr[k].CopyFrom(op.node_def.attr[k]) - in_grads = functional_ops.symbolic_gradient(input=f_in, Tout=f_types, f=f) + # TODO(apassos) use a better dtype here + in_grads = functional_ops.symbolic_gradient( + input=f_in, + Tout=[x if x != dtypes.resource else dtypes.float32 for x in f_types], + f=f) return in_grads @@ -524,10 +508,10 @@ def gradients(ys, RuntimeError: if called in Eager mode. """ - # Creating the gradient graph for control flow mutates Operations. _lock - # ensures a Session.run call cannot occur between creating and mutating new - # ops. - with ops.get_default_graph()._lock: # pylint: disable=protected-access + # Creating the gradient graph for control flow mutates Operations. + # _mutation_lock ensures a Session.run call cannot occur between creating and + # mutating new ops. + with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access return _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients, aggregation_method, stop_gradients) @@ -543,9 +527,8 @@ def _GradientsHelper(ys, src_graph=None): """Implementation of gradients().""" if context.executing_eagerly(): - raise RuntimeError("tf.gradients not supported when eager execution " - "is enabled. Use tf.contrib.eager.GradientTape " - "instead.") + raise RuntimeError("tf.gradients is not supported when eager execution " + "is enabled. Use tf.GradientTape instead.") if src_graph is None: src_graph = ops.get_default_graph() diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index d81c756f1cbc0a46d094066cda369067f7d3d1f6..d70cd088c9a03a45131d6f83663f36e8960f4bd9 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -57,90 +57,8 @@ from tensorflow.python.ops.nn_ops import bias_add from tensorflow.python.platform import googletest -def _OpsBetween(to_ops, from_ops): - """Build the list of operations between two lists of Operations. - - Args: - to_ops: list of Operations. - from_ops: list of Operations. - - Returns: - The list of operations between "from_ops" and "to_ops", sorted by - decreasing operation id. This list contains all elements of to_ops. - - TODO(touts): Think about returning an empty list if from_ops are not - reachable from to_ops. Presently it returns to_ops in that case. - """ - # Ops that are reachable from the output of "input_ops". - reached_ops = set() - # We only care to reach up to "output_ops" so we mark the - # output ops as reached to avoid recursing past them. - for op in to_ops: - reached_ops.add(op) - gradients_impl._MarkReachedOps(from_ops, reached_ops) - between_ops = gradients_impl._GatherInputs(to_ops, reached_ops) - between_ops.sort(key=lambda x: -x._id) - return between_ops - - class GradientsTest(test_util.TensorFlowTestCase): - def _OpNames(self, op_list): - return ["%s/%d" % (str(op.name), op._id) for op in op_list] - - def _assertOpListEqual(self, ops1, ops2): - self.assertEquals(self._OpNames(ops1), self._OpNames(ops2)) - - def testOpsBetweenSimple(self): - with ops.Graph().as_default(): - t1 = constant(1.0) - t2 = constant(2.0) - t3 = array_ops.stack([t1, t2]) - # Full graph - self._assertOpListEqual([t3.op, t2.op, t1.op], - _OpsBetween([t3.op], [t1.op, t2.op])) - # Only t1, t3. - self._assertOpListEqual([t3.op, t1.op], _OpsBetween([t3.op], [t1.op])) - - def testOpsBetweenUnreachable(self): - with ops.Graph().as_default(): - t1 = constant(1.0) - t2 = constant(2.0) - _ = array_ops.stack([t1, t2]) - t4 = constant(1.0) - t5 = constant(2.0) - t6 = array_ops.stack([t4, t5]) - # Elements of to_ops are always listed. - self._assertOpListEqual([t6.op], _OpsBetween([t6.op], [t1.op])) - - def testOpsBetweenCut(self): - with ops.Graph().as_default(): - t1 = constant(1.0) - t2 = constant(2.0) - t3 = array_ops.stack([t1, t2]) - t4 = constant([1.0]) - t5 = array_ops.concat([t4, t3], 0) - t6 = constant([2.0]) - t7 = array_ops.concat([t5, t6], 0) - self._assertOpListEqual([t7.op, t5.op, t4.op], - _OpsBetween([t7.op], [t4.op])) - - def testOpsBetweenCycle(self): - with ops.Graph().as_default(): - t1 = constant(1.0) - t2 = constant(2.0) - t3 = array_ops.stack([t1, t2]) - t4 = array_ops.concat([t3, t3, t3], 0) - t5 = constant([1.0]) - t6 = array_ops.concat([t4, t5], 0) - t7 = array_ops.concat([t6, t3], 0) - self._assertOpListEqual([t6.op, t4.op, t3.op], - _OpsBetween([t6.op], [t3.op])) - self._assertOpListEqual([t7.op, t6.op, t5.op, t4.op, t3.op, t1.op], - _OpsBetween([t7.op], [t1.op, t5.op])) - self._assertOpListEqual([t6.op, t5.op, t4.op, t3.op, t2.op], - _OpsBetween([t6.op], [t2.op, t5.op])) - def testGradients(self): with ops.Graph().as_default(): inp = constant(1.0, shape=[32, 100], name="in") diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 2c7751f7923dca4d0c4f907a673b06ba86b9f342..a2eae452ae551eb1792e5b21477d31c55d64fd79 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -57,6 +57,7 @@ ops.NotDifferentiable('NonMaxSuppression') ops.NotDifferentiable('NonMaxSuppressionV2') +# pylint: disable=invalid-name def _assert(cond, ex_type, msg): """A polymorphic assert, works with tensors and boolean expressions. @@ -1070,15 +1071,16 @@ def resize_images(images, @tf_export('image.resize_image_with_pad') -def resize_image_with_pad(image, target_height, target_width, +def resize_image_with_pad(image, + target_height, + target_width, method=ResizeMethod.BILINEAR): - """ - Resizes and pads an image to a target width and height. + """Resizes and pads an image to a target width and height. Resizes an image to a target width and height by keeping the aspect ratio the same without distortion. If the target dimensions don't match the image dimensions, the image - is resized and then padded with zeroes to match requested + is resized and then padded with zeroes to match requested dimensions. Args: @@ -1139,10 +1141,10 @@ def resize_image_with_pad(image, target_height, target_width, ratio = max_(f_width / f_target_width, f_height / f_target_height) resized_height_float = f_height / ratio resized_width_float = f_width / ratio - resized_height = math_ops.cast(math_ops.floor(resized_height_float), - dtype=dtypes.int32) - resized_width = math_ops.cast(math_ops.floor(resized_width_float), - dtype=dtypes.int32) + resized_height = math_ops.cast( + math_ops.floor(resized_height_float), dtype=dtypes.int32) + resized_width = math_ops.cast( + math_ops.floor(resized_width_float), dtype=dtypes.int32) padding_height = (f_target_height - resized_height_float) / 2 padding_width = (f_target_width - resized_width_float) / 2 @@ -1154,13 +1156,13 @@ def resize_image_with_pad(image, target_height, target_width, # Resize first, then pad to meet requested dimensions resized = resize_images(image, [resized_height, resized_width], method) - padded = pad_to_bounding_box(resized, p_height, p_width, - target_height, target_width) + padded = pad_to_bounding_box(resized, p_height, p_width, target_height, + target_width) if padded.get_shape().ndims is None: raise ValueError('padded contains no shape.') - _, padded_height, padded_width, _ = _ImageDimensions(padded, rank=4) + _ImageDimensions(padded, rank=4) if not is_batch: padded = array_ops.squeeze(padded, squeeze_dims=[0]) diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index 8e40de140df632c9b458c2e2b8a673925ab13634..cf9761803bf9654e21ec12e1f1c7193b3e88c020 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -2731,7 +2731,7 @@ class ResizeImageWithPadTest(test_util.TensorFlowTestCase): try: self._ResizeImageWithPad(x, target_height, target_width, use_tensor_inputs) - except Exception as e: + except Exception as e: # pylint: disable=broad-except if err_msg not in str(e): raise else: diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py index 724fcc39cddbdd2a8acd9c0bbaa7b968c6d1510d..5bfc5ce2a7a1913b097ee67d1b18d684b5ebcaa5 100644 --- a/tensorflow/python/ops/init_ops.py +++ b/tensorflow/python/ops/init_ops.py @@ -43,7 +43,8 @@ from tensorflow.python.ops import linalg_ops_impl from tensorflow.python.ops import gen_linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.util.deprecation import deprecated +from tensorflow.python.util.deprecation import ( + deprecated, deprecated_arg_values) from tensorflow.python.util.tf_export import tf_export @@ -409,8 +410,10 @@ class UniformUnitScaling(Initializer): class VarianceScaling(Initializer): """Initializer capable of adapting its scale to the shape of weights tensors. - With `distribution="normal"`, samples are drawn from a truncated normal - distribution centered on zero, with `stddev = sqrt(scale / n)` + With `distribution="truncated_normal" or "untruncated_normal"`, + samples are drawn from a truncated/untruncated normal + distribution with a mean of zero and a standard deviation (after truncation, + if used) `stddev = sqrt(scale / n)` where n is: - number of input units in the weight tensor, if mode = "fan_in" - number of output units, if mode = "fan_out" @@ -433,10 +436,14 @@ class VarianceScaling(Initializer): "distribution" arguments. """ + @deprecated_arg_values( + None, + "`normal` is a deprecated alias for `truncated_normal`", + distribution="normal") def __init__(self, scale=1.0, mode="fan_in", - distribution="normal", + distribution="truncated_normal", seed=None, dtype=dtypes.float32): if scale <= 0.: @@ -444,7 +451,8 @@ class VarianceScaling(Initializer): if mode not in {"fan_in", "fan_out", "fan_avg"}: raise ValueError("Invalid `mode` argument:", mode) distribution = distribution.lower() - if distribution not in {"normal", "uniform"}: + if distribution not in {"normal", "uniform", + "truncated_normal", "untruncated_normal"}: raise ValueError("Invalid `distribution` argument:", distribution) self.scale = scale self.mode = mode @@ -466,11 +474,15 @@ class VarianceScaling(Initializer): scale /= max(1., fan_out) else: scale /= max(1., (fan_in + fan_out) / 2.) - if self.distribution == "normal": + if self.distribution == "normal" or self.distribution == "truncated_normal": # constant taken from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) stddev = math.sqrt(scale) / .87962566103423978 return random_ops.truncated_normal( shape, 0.0, stddev, dtype, seed=self.seed) + elif self.distribution == "untruncated_normal": + stddev = math.sqrt(scale) + return random_ops.random_normal( + shape, 0.0, stddev, dtype, seed=self.seed) else: limit = math.sqrt(3.0 * scale) return random_ops.random_uniform( @@ -551,7 +563,9 @@ class ConvolutionDeltaOrthogonal(Initializer): The shape of the tensor must have length 3, 4 or 5. The number of input filters must not exceed the number of output filters. The center pixels of the - tensor form an orthogonal matrix. Other pixels are set to be zero. + tensor form an orthogonal matrix. Other pixels are set to be zero. See + algorithm 2 in [Xiao et al., 2018]: https://arxiv.org/abs/1806.05393 + Args: gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1. @@ -672,6 +686,7 @@ class ConvolutionOrthogonal2D(ConvolutionOrthogonal): filters must not exceed the number of output filters. The orthogonality(==isometry) is exact when the inputs are circular padded. There are finite-width effects with non-circular padding (e.g. zero padding). + See algorithm 1 in [Xiao et al., 2018]: https://arxiv.org/abs/1806.05393 Args: gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1. @@ -807,6 +822,7 @@ class ConvolutionOrthogonal1D(ConvolutionOrthogonal): filters must not exceed the number of output filters. The orthogonality(==isometry) is exact when the inputs are circular padded. There are finite-width effects with non-circular padding (e.g. zero padding). + See algorithm 1 in [Xiao et al., 2018]: https://arxiv.org/abs/1806.05393 Args: gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1. @@ -923,6 +939,7 @@ class ConvolutionOrthogonal3D(ConvolutionOrthogonal): filters must not exceed the number of output filters. The orthogonality(==isometry) is exact when the inputs are circular padded. There are finite-width effects with non-circular padding (e.g. zero padding). + See algorithm 1 [Xiao et al., 2018] in: https://arxiv.org/abs/1806.05393 Args: gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1. diff --git a/tensorflow/python/ops/linalg/linear_operator_test_util.py b/tensorflow/python/ops/linalg/linear_operator_test_util.py index 1b5bb9470c4406ad075f2f6d5c38661311472727..78c85db557047ebcc3dd655deae62acbcef929c7 100644 --- a/tensorflow/python/ops/linalg/linear_operator_test_util.py +++ b/tensorflow/python/ops/linalg/linear_operator_test_util.py @@ -102,7 +102,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): raise NotImplementedError("operator_build_infos has not been implemented.") @abc.abstractmethod - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): """Build a batch matrix and an Operator that should have similar behavior. Every operator acts like a (batch) matrix. This method returns both @@ -118,9 +118,6 @@ class LinearOperatorDerivedClassTest(test.TestCase): Returns: operator: `LinearOperator` subclass instance. mat: `Tensor` representing operator. - feed_dict: Dictionary. - If placholder is True, this must contains everything needed to be fed - to sess.run calls at runtime to make the operator work. """ # Create a matrix as a numpy array with desired shape/dtype. # Create a LinearOperator that should have the same behavior as the matrix. @@ -189,12 +186,12 @@ class LinearOperatorDerivedClassTest(test.TestCase): for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) op_dense = operator.to_dense() if not use_placeholder: self.assertAllEqual(build_info.shape, op_dense.get_shape()) - op_dense_v, mat_v = sess.run([op_dense, mat], feed_dict=feed_dict) + op_dense_v, mat_v = sess.run([op_dense, mat]) self.assertAC(op_dense_v, mat_v) def test_det(self): @@ -204,14 +201,13 @@ class LinearOperatorDerivedClassTest(test.TestCase): for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) op_det = operator.determinant() if not use_placeholder: self.assertAllEqual(build_info.shape[:-2], op_det.get_shape()) op_det_v, mat_det_v = sess.run( - [op_det, linalg_ops.matrix_determinant(mat)], - feed_dict=feed_dict) + [op_det, linalg_ops.matrix_determinant(mat)]) self.assertAC(op_det_v, mat_det_v) def test_log_abs_det(self): @@ -221,7 +217,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) op_log_abs_det = operator.log_abs_determinant() _, mat_log_abs_det = linalg.slogdet(mat) @@ -229,7 +225,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): self.assertAllEqual( build_info.shape[:-2], op_log_abs_det.get_shape()) op_log_abs_det_v, mat_log_abs_det_v = sess.run( - [op_log_abs_det, mat_log_abs_det], feed_dict=feed_dict) + [op_log_abs_det, mat_log_abs_det]) self.assertAC(op_log_abs_det_v, mat_log_abs_det_v) def _test_matmul(self, with_batch): @@ -246,7 +242,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): for adjoint_arg in self._adjoint_arg_options: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) x = self._make_x( operator, adjoint=adjoint, with_batch=with_batch) @@ -264,7 +260,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): self.assertAllEqual(op_matmul.get_shape(), mat_matmul.get_shape()) op_matmul_v, mat_matmul_v = sess.run( - [op_matmul, mat_matmul], feed_dict=feed_dict) + [op_matmul, mat_matmul]) self.assertAC(op_matmul_v, mat_matmul_v) def test_matmul(self): @@ -289,7 +285,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): for adjoint_arg in self._adjoint_arg_options: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) rhs = self._make_rhs( operator, adjoint=adjoint, with_batch=with_batch) @@ -307,8 +303,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): if not use_placeholder: self.assertAllEqual(op_solve.get_shape(), mat_solve.get_shape()) - op_solve_v, mat_solve_v = sess.run( - [op_solve, mat_solve], feed_dict=feed_dict) + op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve]) self.assertAC(op_solve_v, mat_solve_v) def test_solve(self): @@ -326,14 +321,13 @@ class LinearOperatorDerivedClassTest(test.TestCase): for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) op_trace = operator.trace() mat_trace = math_ops.trace(mat) if not use_placeholder: self.assertAllEqual(op_trace.get_shape(), mat_trace.get_shape()) - op_trace_v, mat_trace_v = sess.run( - [op_trace, mat_trace], feed_dict=feed_dict) + op_trace_v, mat_trace_v = sess.run([op_trace, mat_trace]) self.assertAC(op_trace_v, mat_trace_v) def test_add_to_tensor(self): @@ -343,15 +337,14 @@ class LinearOperatorDerivedClassTest(test.TestCase): for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) op_plus_2mat = operator.add_to_tensor(2 * mat) if not use_placeholder: self.assertAllEqual(build_info.shape, op_plus_2mat.get_shape()) - op_plus_2mat_v, mat_v = sess.run( - [op_plus_2mat, mat], feed_dict=feed_dict) + op_plus_2mat_v, mat_v = sess.run([op_plus_2mat, mat]) self.assertAC(op_plus_2mat_v, 3 * mat_v) @@ -362,7 +355,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) op_diag_part = operator.diag_part() mat_diag_part = array_ops.matrix_diag_part(mat) @@ -372,7 +365,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): op_diag_part.get_shape()) op_diag_part_, mat_diag_part_ = sess.run( - [op_diag_part, mat_diag_part], feed_dict=feed_dict) + [op_diag_part, mat_diag_part]) self.assertAC(op_diag_part_, mat_diag_part_) diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py index de9b3c6909ddd9c22ac4bced5ec48e4de354bd19..66633c8b12f60c86760f906aa8e4312c7394e796 100644 --- a/tensorflow/python/ops/losses/losses_impl.py +++ b/tensorflow/python/ops/losses/losses_impl.py @@ -192,6 +192,11 @@ def compute_weighted_loss( on some model parameters but you do not want this to affect the loss gradient, you need to apply @{tf.stop_gradient} to `weights` before passing them to `compute_weighted_loss`. + + @compatbility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ Reduction.validate(reduction) with ops.name_scope(scope, "weighted_loss", (losses, weights)): @@ -260,6 +265,11 @@ def absolute_difference( ValueError: If the shape of `predictions` doesn't match that of `labels` or if the shape of `weights` is invalid or if `labels` or `predictions` is None. + + @compatbility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if labels is None: raise ValueError("labels must not be None.") @@ -306,6 +316,11 @@ def cosine_distance( Raises: ValueError: If `predictions` shape doesn't match `labels` shape, or `axis`, `labels`, `predictions` or `weights` is `None`. + + @compatbility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ axis = deprecated_argument_lookup("axis", axis, "dim", dim) if axis is None: @@ -353,6 +368,11 @@ def hinge_loss(labels, logits, weights=1.0, scope=None, Raises: ValueError: If the shapes of `logits` and `labels` don't match or if `labels` or `logits` is None. + + @compatbility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if labels is None: raise ValueError("labels must not be None.") @@ -416,6 +436,11 @@ def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None, ValueError: If the shape of `predictions` doesn't match that of `labels` or if the shape of `weights` is invalid. Also if `labels` or `predictions` is None. + + @compatbility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if labels is None: raise ValueError("labels must not be None.") @@ -477,6 +502,11 @@ def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None, ValueError: If the shape of `predictions` doesn't match that of `labels` or if the shape of `weights` is invalid. Also if `labels` or `predictions` is None. + + @compatbility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if labels is None: raise ValueError("labels must not be None.") @@ -540,6 +570,11 @@ def mean_pairwise_squared_error( ValueError: If the shape of `predictions` doesn't match that of `labels` or if the shape of `weights` is invalid. Also if `labels` or `predictions` is None. + + @compatbility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if labels is None: raise ValueError("labels must not be None.") @@ -618,6 +653,11 @@ def mean_squared_error( ValueError: If the shape of `predictions` doesn't match that of `labels` or if the shape of `weights` is invalid. Also if `labels` or `predictions` is None. + + @compatbility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if labels is None: raise ValueError("labels must not be None.") @@ -670,6 +710,11 @@ def sigmoid_cross_entropy( ValueError: If the shape of `logits` doesn't match that of `multi_class_labels` or if the shape of `weights` is invalid, or if `weights` is None. Also if `multi_class_labels` or `logits` is None. + + @compatbility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if multi_class_labels is None: raise ValueError("multi_class_labels must not be None.") @@ -731,6 +776,11 @@ def softmax_cross_entropy( ValueError: If the shape of `logits` doesn't match that of `onehot_labels` or if the shape of `weights` is invalid or if `weights` is None. Also if `onehot_labels` or `logits` is None. + + @compatbility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if onehot_labels is None: raise ValueError("onehot_labels must not be None.") @@ -828,7 +878,8 @@ def sparse_softmax_cross_entropy( exception when this op is run on CPU, and return `NaN` for corresponding loss and gradient rows on GPU. logits: Unscaled log probabilities of shape - `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float32` or `float64`. + `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float16`, `float32` or + `float64`. weights: Coefficients for the loss. This must be scalar or broadcastable to `labels` (i.e. same rank and each dimension is either 1 or the same). scope: the scope for the operations performed in computing the loss. @@ -842,6 +893,11 @@ def sparse_softmax_cross_entropy( Raises: ValueError: If the shapes of `logits`, `labels`, and `weights` are incompatible, or if any of them are None. + + @compatbility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if labels is None: raise ValueError("labels must not be None.") diff --git a/tensorflow/python/ops/losses/util.py b/tensorflow/python/ops/losses/util.py index 10646af8a983f149cf0620bf355cf0bc1fa697fb..97bba46661d056fd336c68988e3bc17ef4232487 100644 --- a/tensorflow/python/ops/losses/util.py +++ b/tensorflow/python/ops/losses/util.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops @@ -32,7 +33,10 @@ def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES): loss: A loss `Tensor`. loss_collection: Optional collection to add the loss to. """ - if loss_collection: + # Since we have no way of figuring out when a training iteration starts or + # ends, holding on to a loss when executing eagerly is indistingishable from + # leaking memory. We instead leave the collection empty. + if loss_collection and not context.executing_eagerly(): ops.add_to_collection(loss_collection, loss) diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index a48b3c9395f48ecaf3879b95199f69c84f7f095a..f0c6bd532fcdb76922ce4d5aa7fa13936db81b2f 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -651,27 +651,28 @@ def _BesselI1eGrad(op, grad): @ops.RegisterGradient("Igamma") def _IgammaGrad(op, grad): - """Returns gradient of igamma(a, x) with respect to x.""" - # TODO(ebrevdo): Perhaps add the derivative w.r.t. a + """Returns gradient of igamma(a, x) with respect to a and x.""" a = op.inputs[0] x = op.inputs[1] sa = array_ops.shape(a) sx = array_ops.shape(x) - unused_ra, rx = gen_array_ops.broadcast_gradient_args(sa, sx) + ra, rx = gen_array_ops.broadcast_gradient_args(sa, sx) - # Perform operations in log space before summing, because Gamma(a) - # and Gamma'(a) can grow large. - partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) - math_ops.lgamma(a)) - # TODO(b/36815900): Mark None return values as NotImplemented - return (None, array_ops.reshape( - math_ops.reduce_sum(partial_x * grad, rx), sx)) + with ops.control_dependencies([grad]): + partial_a = gen_math_ops.igamma_grad_a(a, x) + # Perform operations in log space before summing, because Gamma(a) + # and Gamma'(a) can grow large. + partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) + - math_ops.lgamma(a)) + return (array_ops.reshape(math_ops.reduce_sum(partial_a * grad, ra), sa), + array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) @ops.RegisterGradient("Igammac") def _IgammacGrad(op, grad): - """Returns gradient of igammac(a, x) = 1 - igamma(a, x) w.r.t. x.""" - _, igamma_grad_x = _IgammaGrad(op, grad) - return None, -igamma_grad_x + """Returns gradient of igammac(a, x) = 1 - igamma(a, x) w.r.t. a and x.""" + igamma_grad_a, igamma_grad_x = _IgammaGrad(op, grad) + return (-igamma_grad_a, -igamma_grad_x) @ops.RegisterGradient("Betainc") diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 466d0dadc8a430861ab27b6a522ca6acd2db7855..cdb6dc8f22919420ff44e217578315d17cb93d8c 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -1990,7 +1990,7 @@ def matmul(a, sparse_matmul_types = [dtypes.bfloat16, dtypes.float32] use_sparse_matmul = ( a.dtype in sparse_matmul_types and b.dtype in sparse_matmul_types) - if (a.dtype == dtypes.bfloat16 or b.dtype == dtypes.bfloat16 and + if ((a.dtype == dtypes.bfloat16 or b.dtype == dtypes.bfloat16) and a.dtype != b.dtype): # matmul currently doesn't handle mixed-precision inputs. use_sparse_matmul = True diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index c807c8bc2efbf867700f0df37783b02fefa0ca82..6b709e5e7faf0a74f966f446ba9d33ee1087908a 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -37,14 +37,14 @@ log = np.log class ReduceTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testReduceAllDims(self): x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) with test_util.device(use_gpu=True): y_tf = self.evaluate(math_ops.reduce_sum(x)) self.assertEqual(y_tf, 21) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testReduceExplicitAxes(self): x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) with test_util.device(use_gpu=True): @@ -57,7 +57,7 @@ class ReduceTest(test_util.TensorFlowTestCase): for axis in (None, (0, 1), (-1, -2), (-2, -1, 0, 1)): self.assertEqual(self.evaluate(math_ops.reduce_sum(x, axis=axis)), 21) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testReduceInvalidAxis(self): if context.executing_eagerly(): # The shape check is in run a graph construction time. In eager mode, @@ -150,7 +150,7 @@ class LogSumExpTest(test_util.TensorFlowTestCase): class RoundTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRounding(self): x = np.arange(-5.0, 5.0, .25) for dtype in [np.float32, np.double, np.int32]: @@ -194,7 +194,7 @@ class ModTest(test_util.TensorFlowTestCase): class SquaredDifferenceTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSquaredDifference(self): for dtype in [np.int32, np.float16]: x = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype) @@ -207,7 +207,7 @@ class SquaredDifferenceTest(test_util.TensorFlowTestCase): class ApproximateEqualTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testApproximateEqual(self): for dtype in [np.float32, np.double]: x = dtype(1) @@ -237,8 +237,8 @@ class ApproximateEqualTest(test_util.TensorFlowTestCase): def testApproximateEqualShape(self): for dtype in [np.float32, np.double]: - x = np.array([1, 2], dtype=np.float32) - y = np.array([[1, 2]], dtype=np.float32) + x = np.array([1, 2], dtype=dtype) + y = np.array([[1, 2]], dtype=dtype) # The inputs 'x' and 'y' must have the same shape. with self.assertRaisesRegexp( ValueError, "Shapes must be equal rank, but are 1 and 2"): @@ -247,7 +247,7 @@ class ApproximateEqualTest(test_util.TensorFlowTestCase): class ScalarMulTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAcceptsRefs(self): if context.executing_eagerly(): var = resource_variable_ops.ResourceVariable(10, name="var") @@ -259,14 +259,14 @@ class ScalarMulTest(test_util.TensorFlowTestCase): self.evaluate(init) self.assertEqual(30, self.evaluate(result)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAcceptsConstant(self): const = constant_op.constant(10) result = math_ops.scalar_mul(3, const) with test_util.device(use_gpu=True): self.assertEqual(30, self.evaluate(result)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAcceptsTensor(self): tensor = array_ops.ones([10, 10]) result = math_ops.scalar_mul(3, tensor) @@ -275,7 +275,7 @@ class ScalarMulTest(test_util.TensorFlowTestCase): with test_util.device(use_gpu=True): self.assertAllEqual(self.evaluate(expected), self.evaluate(result)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAcceptsIndexedSlices(self): values = constant_op.constant([2, 3, 5, 7, 0, -1], shape=[3, 2]) indices = constant_op.constant([0, 2, 5]) diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 0c2f5b06c497e8ca7db20ac09938c86b425d66a0..41d54a6c2f9d8cd961cea398da679fd81361b848 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -2009,7 +2009,8 @@ def sparse_softmax_cross_entropy_with_logits( exception when this op is run on CPU, and return `NaN` for corresponding loss and gradient rows on GPU. logits: Unscaled log probabilities of shape - `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float32` or `float64`. + `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float16`, `float32`, or + `float64`. name: A name for the operation (optional). Returns: @@ -2166,7 +2167,7 @@ def _calc_conv_flops(graph, node): filter_height = int(filter_shape[0]) filter_width = int(filter_shape[1]) filter_in_depth = int(filter_shape[2]) - output_count = np.prod(output_shape.as_list()) + output_count = np.prod(output_shape.as_list(), dtype=np.int64) return ops.OpStats( "flops", (output_count * filter_in_depth * filter_height * filter_width * 2)) @@ -2184,7 +2185,7 @@ def _calc_depthwise_conv_flops(graph, node): output_shape.assert_is_fully_defined() filter_height = int(filter_shape[0]) filter_width = int(filter_shape[1]) - output_count = np.prod(output_shape.as_list()) + output_count = np.prod(output_shape.as_list(), dtype=np.int64) return ops.OpStats("flops", (output_count * filter_height * filter_width * 2)) @@ -2594,7 +2595,7 @@ def _calc_dilation2d_flops(graph, node): output_shape.assert_is_fully_defined() filter_height = int(filter_shape[0]) filter_width = int(filter_shape[1]) - output_count = np.prod(output_shape.as_list()) + output_count = np.prod(output_shape.as_list(), dtype=np.int64) return ops.OpStats("flops", (output_count * filter_height * filter_width * 2)) diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 035b4735affbd37f9de94057eed6f7b5d9aadd6e..ae24ca0552e7ba2823ec9404ecc848f510cce464 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -76,7 +76,7 @@ class SoftmaxTest(test_lib.TestCase): z = u.sum(1)[:, np.newaxis] return u / z - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSoftmax(self): x_shape = [5, 10] x_np = np.random.randn(*x_shape).astype(np.float32) @@ -123,7 +123,7 @@ class LogPoissonLossTest(test_lib.TestCase): lpl += np.ma.masked_array(stirling_approx, mask=(z <= 1)).filled(0.) return lpl - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLogPoissonLoss(self): x_shape = [5, 10] x_np = np.random.randn(*x_shape).astype(np.float32) @@ -164,7 +164,7 @@ class LogSoftmaxTest(test_lib.TestCase): u = x - m return u - np.log(np.sum(np.exp(u), 1, keepdims=True)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLogSoftmax(self): x_shape = [5, 10] x_np = np.random.randn(*x_shape).astype(np.float32) @@ -201,7 +201,7 @@ class LogSoftmaxTest(test_lib.TestCase): class L2LossTest(test_lib.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testL2Loss(self): for dtype in [dtypes.float32, dtypes.float64]: x = constant_op.constant( @@ -235,7 +235,7 @@ class L2NormalizeTest(test_lib.TestCase): norm = np.apply_along_axis(np.linalg.norm, dim, x) return x / np.expand_dims(norm, dim) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testL2Normalize(self): x_shape = [20, 7, 3] np.random.seed(1) @@ -246,7 +246,7 @@ class L2NormalizeTest(test_lib.TestCase): y_tf = nn_impl.l2_normalize(x_tf, dim) self.assertAllClose(y_np, self.evaluate(y_tf)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testL2NormalizeDimArray(self): x_shape = [20, 7, 3] np.random.seed(1) diff --git a/tensorflow/python/ops/random_grad.py b/tensorflow/python/ops/random_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..baa8e2e2cd33d37312b5b14bea3c248c06ff2e50 --- /dev/null +++ b/tensorflow/python/ops/random_grad.py @@ -0,0 +1,65 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Gradients for operators defined in random_ops.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_random_ops +from tensorflow.python.ops import math_ops + + +def add_leading_unit_dimensions(x, num_dimensions): + new_shape = array_ops.concat( + [array_ops.ones([num_dimensions], dtype=dtypes.int32), + array_ops.shape(x)], axis=0) + return array_ops.reshape(x, new_shape) + + +@ops.RegisterGradient("RandomGamma") +def _RandomGammaGrad(op, grad): # pylint: disable=invalid-name + """Returns the gradient of a Gamma sample w.r.t. alpha. + + The gradient is computed using implicit differentiation, see + "Implicit Reparameterization Gradients" (https://arxiv.org/abs/1805.08498). + + Args: + op: A `RandomGamma` operation. We assume that the inputs to the operation + are `shape` and `alpha` tensors, and the output is the `sample` tensor. + grad: The incoming gradient `dloss / dsample` of the same shape as + `op.outputs[0]`. + + Returns: + A `Tensor` with derivatives `dloss / dalpha` + """ + shape = op.inputs[0] + alpha = op.inputs[1] + sample = op.outputs[0] + + with ops.control_dependencies([grad]): + # Make the parameters alpha broadcastable with samples by appending + # unit dimensions. + num_sample_dimensions = array_ops.shape(shape)[0] + alpha_broadcastable = add_leading_unit_dimensions( + alpha, num_sample_dimensions) + partial_a = gen_random_ops.random_gamma_grad(alpha_broadcastable, sample) + + # The first input is shape; the second input is alpha. + return (None, math_ops.reduce_sum( + grad * partial_a, axis=math_ops.range(num_sample_dimensions))) diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py index 6a2dd3f1cd55eea1d3b652a31cd2784c411c2ce0..b8738adf66e6ff51962ed44dce7cd4b95544e271 100644 --- a/tensorflow/python/ops/random_ops.py +++ b/tensorflow/python/ops/random_ops.py @@ -368,25 +368,41 @@ def random_gamma(shape, `alpha` is the shape parameter describing the distribution(s), and `beta` is the inverse scale parameter(s). - Example: + Note: Because internal calculations are done using `float64` and casting has + `floor` semantics, we must manually map zero outcomes to the smallest + possible positive floating-point value, i.e., `np.finfo(dtype).tiny`. This + means that `np.finfo(dtype).tiny` occurs more frequently than it otherwise + should. This bias can only happen for small values of `alpha`, i.e., + `alpha << 1` or large values of `beta`, i.e., `beta >> 1`. - samples = tf.random_gamma([10], [0.5, 1.5]) - # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents - # the samples drawn from each distribution + The samples are differentiable w.r.t. alpha and beta. + The derivatives are computed using the approach described in the paper - samples = tf.random_gamma([7, 5], [0.5, 1.5]) - # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1] - # represents the 7x5 samples drawn from each of the two distributions + [Michael Figurnov, Shakir Mohamed, Andriy Mnih. + Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498) - samples = tf.random_gamma([30], [[1.],[3.],[5.]], beta=[[3., 4.]]) - # samples has shape [30, 3, 2], with 30 samples each of 3x2 distributions. + Example: - Note: Because internal calculations are done using `float64` and casting has - `floor` semantics, we must manually map zero outcomes to the smallest - possible positive floating-point value, i.e., `np.finfo(dtype).tiny`. This - means that `np.finfo(dtype).tiny` occurs more frequently than it otherwise - should. This bias can only happen for small values of `alpha`, i.e., - `alpha << 1` or large values of `beta`, i.e., `beta >> 1`. + ```python + samples = tf.random_gamma([10], [0.5, 1.5]) + # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents + # the samples drawn from each distribution + + samples = tf.random_gamma([7, 5], [0.5, 1.5]) + # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1] + # represents the 7x5 samples drawn from each of the two distributions + + alpha = tf.constant([[1.],[3.],[5.]]) + beta = tf.constant([[3., 4.]]) + samples = tf.random_gamma([30], alpha=alpha, beta=beta) + # samples has shape [30, 3, 2], with 30 samples each of 3x2 distributions. + + loss = tf.reduce_mean(tf.square(samples)) + dloss_dalpha, dloss_dbeta = tf.gradients(loss, [alpha, beta]) + # unbiased stochastic derivatives of the loss function + alpha.shape == dloss_dalpha.shape # True + beta.shape == dloss_dbeta.shape # True + ``` Args: shape: A 1-D integer Tensor or Python array. The shape of the output samples @@ -406,8 +422,9 @@ def random_gamma(shape, name: Optional name for the operation. Returns: - samples: a `Tensor` of shape `tf.concat(shape, tf.shape(alpha + beta))` - with values of type `dtype`. + samples: a `Tensor` of shape + `tf.concat([shape, tf.shape(alpha + beta)], axis=0)` with values of type + `dtype`. """ with ops.name_scope(name, "random_gamma", [shape, alpha, beta]): shape = ops.convert_to_tensor(shape, name="shape", dtype=dtypes.int32) @@ -421,8 +438,6 @@ def random_gamma(shape, gen_random_ops.random_gamma( shape, alpha_broadcast, seed=seed1, seed2=seed2) / beta) -ops.NotDifferentiable("RandomGamma") - @tf_export("random_poisson") def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None): @@ -432,13 +447,15 @@ def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None): Example: - samples = tf.random_poisson([0.5, 1.5], [10]) - # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents - # the samples drawn from each distribution + ```python + samples = tf.random_poisson([0.5, 1.5], [10]) + # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents + # the samples drawn from each distribution - samples = tf.random_poisson([12.2, 3.3], [7, 5]) - # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1] - # represents the 7x5 samples drawn from each of the two distributions + samples = tf.random_poisson([12.2, 3.3], [7, 5]) + # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1] + # represents the 7x5 samples drawn from each of the two distributions + ``` Args: lam: A Tensor or Python value or N-D array of type `dtype`. @@ -455,8 +472,8 @@ def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None): name: Optional name for the operation. Returns: - samples: a `Tensor` of shape `tf.concat(shape, tf.shape(lam))` with - values of type `dtype`. + samples: a `Tensor` of shape `tf.concat([shape, tf.shape(lam)], axis=0)` + with values of type `dtype`. """ with ops.name_scope(name, "random_poisson", [lam, shape]): shape = ops.convert_to_tensor(shape, name="shape", dtype=dtypes.int32) diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index de44a3e848586d92b3b3155edfbfcadc47755089..15cafbbde50335de0dc0cd8849425c07b4ac81d3 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -851,14 +851,15 @@ class ResourceVariable(variables.Variable): operator: string. The operator name. """ + tensor_oper = getattr(ops.Tensor, operator) def _run_op(a, *args): # pylint: disable=protected-access value = a._AsTensor() - return getattr(ops.Tensor, operator)(value, *args) + return tensor_oper(value, *args) # Propagate __doc__ to wrapper try: - _run_op.__doc__ = getattr(ops.Tensor, operator).__doc__ + _run_op.__doc__ = tensor_oper.__doc__ except AttributeError: pass @@ -998,32 +999,28 @@ class ResourceVariable(variables.Variable): def __imul__(self, unused_other): raise RuntimeError("Variable *= value not supported. Use " - "variable.assign_mul(value) to modify the variable " - "value and variable = variable * value to get a new " - "Tensor object.") + "`var.assign(var * value)` to modify the variable or " + "`var = var * value` to get a new Tensor object.") def __idiv__(self, unused_other): raise RuntimeError("Variable /= value not supported. Use " - "variable.assign_div(value) to modify the variable " - "value and variable = variable / value to get a new " - "Tensor object.") + "`var.assign(var / value)` to modify the variable or " + "`var = var / value` to get a new Tensor object.") def __itruediv__(self, unused_other): raise RuntimeError("Variable /= value not supported. Use " - "variable.assign_div(value) to modify the variable " - "value and variable = variable / value to get a new " - "Tensor object.") + "`var.assign(var / value)` to modify the variable or " + "`var = var / value` to get a new Tensor object.") def __irealdiv__(self, unused_other): raise RuntimeError("Variable /= value not supported. Use " - "variable.assign_div(value) to modify the variable " - "value and variable = variable / value to get a new " - "Tensor object.") + "`var.assign(var / value)` to modify the variable or " + "`var = var / value` to get a new Tensor object.") def __ipow__(self, unused_other): raise RuntimeError("Variable **= value not supported. Use " - "value and variable = variable ** value to get a new " - "Tensor object.") + "`var.assign(var ** value)` to modify the variable or " + "`var = var ** value` to get a new Tensor object.") pywrap_tensorflow.TFE_Py_RegisterResourceVariableType(ResourceVariable) diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index 10d576c95bc4fd3147da44ee1522dc829bcab83d..deba133fb9910f28c7f902f334174734c3c742f7 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import tensor_array_ops @@ -131,6 +132,18 @@ def _maybe_tensor_shape_from_tensor(shape): return shape +def _should_cache(): + """Returns True if a default caching device should be set, otherwise False.""" + if context.executing_eagerly(): + return False + # Don't set a caching device when running in a loop, since it is possible that + # train steps could be wrapped in a tf.while_loop. In that scenario caching + # prevents forward computations in loop iterations from re-reading the + # updated weights. + ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access + return control_flow_util.GetContainingWhileContext(ctxt) is None + + # pylint: disable=unused-argument def _rnn_step( time, sequence_length, min_sequence_length, max_sequence_length, @@ -558,7 +571,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, # Create a new scope in which the caching device is either # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. - if not context.executing_eagerly(): + if _should_cache(): if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) @@ -828,7 +841,8 @@ def _dynamic_rnn_loop(cell, final_outputs = nest.pack_sequence_as( structure=cell.output_size, flat_sequence=final_outputs) if not in_graph_mode: - final_outputs = array_ops.stack(final_outputs, axis=0) + final_outputs = nest.map_structure_up_to( + cell.output_size, lambda x: array_ops.stack(x, axis=0), final_outputs) return (final_outputs, final_state) @@ -1014,7 +1028,7 @@ def raw_rnn(cell, loop_fn, # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. with vs.variable_scope(scope or "rnn") as varscope: - if not context.executing_eagerly(): + if _should_cache(): if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) @@ -1227,7 +1241,7 @@ def static_rnn(cell, # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. with vs.variable_scope(scope or "rnn") as varscope: - if not context.executing_eagerly(): + if _should_cache(): if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index 05723c6960af3772d9576756ee94bd19f562edd1..82a044a0d4c8710f5ade0aa460f4354a0dd35deb 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -47,6 +47,7 @@ from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import tracking as checkpointable_tracking from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -1331,7 +1332,7 @@ class MultiRNNCell(RNNCell): return cur_inp, new_states -class _SlimRNNCell(RNNCell, checkpointable.NotCheckpointable): +class _SlimRNNCell(RNNCell, checkpointable_tracking.NotCheckpointable): """A simple wrapper for slim.rnn_cells.""" def __init__(self, cell_fn): diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index cc23d0d133ecdb1415e0effcbc2ce52a962fb41e..1e3f662ff34f67d2b5f226427c8a03d82b9f2a7c 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -34,6 +34,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_script_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.util import compat @@ -96,28 +97,27 @@ class EagerFunc(object): return constant_op.constant(0.0, dtype=dtype) return ops.convert_to_tensor(value, dtype=dtype) - def __call__(self, on_gpu, token, args): + def __call__(self, device, token, args): """Passes `args` to `self._func`, which is executed eagerly.""" - with context.eager_mode(): - with backprop.GradientTape() as tape: - for tensor in args: - tape.watch(tensor) - ret = self._func(*args) - # NB: The tape needs to watch copies across devices. - maybe_copy_to_gpu = lambda x: x if not on_gpu else x.gpu() + with context.eager_mode(), backprop.GradientTape() as tape: + for tensor in args: + tape.watch(tensor) + ret = self._func(*args) + # Use tf.identity to copy the returned tensors to device if neccesary. + with ops.device(device): if isinstance(ret, (tuple, list)): outputs = [ - maybe_copy_to_gpu(self._convert(x, dtype=dtype)) + array_ops.identity(self._convert(x, dtype=dtype)) for (x, dtype) in zip(ret, self._out_dtypes) ] elif ret is None: outputs = None else: - outputs = maybe_copy_to_gpu( + outputs = array_ops.identity( self._convert(ret, dtype=self._out_dtypes[0])) - tape_cache[compat.as_bytes(token)] = (tape, args, outputs) - return outputs + tape_cache[compat.as_bytes(token)] = (tape, args, outputs) + return outputs class FuncRegistry(object): @@ -174,14 +174,14 @@ class FuncRegistry(object): else: return result - def __call__(self, token, on_gpu, args): + def __call__(self, token, device, args): """Calls the registered function for `token` with args. Args: token: A key into this `FuncRegistry` identifying which function to call. - on_gpu: A boolean indicating whether or not `token`'s corresponding - operation was placed on GPU; only used if the function registered for - `token` is an `EagerPyFunc`. + device: Name of the device on which outputs of `token`'s corresponding + operation should be placed. Used iff the function registered for `token` + is an EagerPyFunc. args: The arguments to pass to the function registered for `token`. Returns: @@ -201,7 +201,7 @@ class FuncRegistry(object): # or if the graph is being driven by concurrent session.run() calls. # # TODO(akshayka): Key the tape cache in a thread-safe way. - return func(on_gpu, token, args) + return func(device, token, args) else: ret = func(*args) # Strings seem to lead to a memory leak here if they're not wrapped in a @@ -232,8 +232,13 @@ _py_funcs = FuncRegistry() pywrap_tensorflow.InitializePyTrampoline(_py_funcs) -def _internal_py_func(func, inp, Tout, stateful=None, eager=False, - is_grad_func=False, name=None): +def _internal_py_func(func, + inp, + Tout, + stateful=None, + eager=False, + is_grad_func=False, + name=None): """See documentation for py_func and eager_py_func.""" is_list_or_tuple = False @@ -296,7 +301,8 @@ def _EagerPyFuncGrad(op, dy): func=eagerly_executed_grad, inp=[dy] if isinstance(dy, ops.Tensor) else dy, Tout=[tensor.dtype for tensor in op.inputs], - eager=True, is_grad_func=True) + eager=True, + is_grad_func=True) def eager_py_func(func, inp, Tout, name=None): @@ -337,7 +343,7 @@ def eager_py_func(func, inp, Tout, name=None): or print statements as desired, and wrap those functions in `tf.contrib.eager.py_func`. - For more information on eager execution, see @{$programmers_guide/eager}. + For more information on eager execution, see @{$guide/eager}. `tf.contrib.eager.py_func` is similar in spirit to @{tf.py_func}, but unlike the latter, the former lets you use TensorFlow operations in the wrapped diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py index 97353d6c747cb7e4d3c1fa92ad61af24fb17de91..1223b290ff6cfcfba27f40c05556c85b59e77148 100644 --- a/tensorflow/python/ops/sparse_grad.py +++ b/tensorflow/python/ops/sparse_grad.py @@ -116,6 +116,35 @@ def _SparseReduceSumGrad(op, out_grad): None, None) +@ops.RegisterGradient("SparseSlice") +def _SparseSliceGrad(op, *grads): + """The backward operator for the SparseSlice op. + + This op takes in the upstream gradient w.r.t. non-empty values of + the sliced `SparseTensor`, and outputs the gradients w.r.t. + the non-empty values of input `SparseTensor`. + + Args: + op: the SparseSlice op + *grads: the incoming gradients, one element per output of `op` + + Returns: + Gradient for each of the 5 input tensors of SparseSlice: + (indices, values, shape, start, size) + The gradients for the indices, shape, start and the size are None. + """ + backprop_val_grad = grads[1] + input_indices = op.inputs[0] + input_start = op.inputs[3] + output_indices = op.outputs[0] + + val_grad = gen_sparse_ops.sparse_slice_grad( + backprop_val_grad, input_indices, input_start, output_indices) + val_grad.set_shape(op.inputs[1].get_shape()) + # (indices, values, shape, start, size) + return (None, val_grad, None, None, None) + + @ops.RegisterGradient("SparseTensorDenseMatMul") def _SparseTensorDenseMatMulGrad(op, grad): """Gradients for the dense tensor in the SparseTensorDenseMatMul op. diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py index 1508873b751c4fa42d3488ff2d18b5795fda9652..9a10abfcf736be783bfcd7907ec6f357912828ab 100644 --- a/tensorflow/python/ops/special_math_ops.py +++ b/tensorflow/python/ops/special_math_ops.py @@ -34,7 +34,7 @@ from tensorflow.python.util.tf_export import tf_export # TODO(b/27419586) Change docstring for required dtype of x once int allowed @tf_export('lbeta') -def lbeta(x, name='lbeta'): +def lbeta(x, name=None): r"""Computes \\(ln(|Beta(x)|)\\), reducing along the last dimension. Given one-dimensional `z = [z_0,...,z_{K-1}]`, we define @@ -64,7 +64,7 @@ def lbeta(x, name='lbeta'): # This is consistent with a convention that the sum over the empty set 0, and # the product is 1. # This is standard. See https://en.wikipedia.org/wiki/Empty_set. - with ops.name_scope(name, values=[x]): + with ops.name_scope(name, 'lbeta', [x]): x = ops.convert_to_tensor(x, name='x') # Note reduce_sum([]) = 0. @@ -83,7 +83,7 @@ def lbeta(x, name='lbeta'): @tf_export('math.bessel_i0') -def bessel_i0(x, name='bessel_i0'): +def bessel_i0(x, name=None): """Computes the Bessel i0 function of `x` element-wise. Modified Bessel function of order 0. @@ -102,12 +102,12 @@ def bessel_i0(x, name='bessel_i0'): Equivalent to scipy.special.i0 @end_compatibility """ - with ops.name_scope(name, [x]): + with ops.name_scope(name, 'bessel_i0', [x]): return math_ops.exp(math_ops.abs(x)) * math_ops.bessel_i0e(x) @tf_export('math.bessel_i1') -def bessel_i1(x, name='bessel_i1'): +def bessel_i1(x, name=None): """Computes the Bessel i1 function of `x` element-wise. Modified Bessel function of order 1. @@ -126,7 +126,7 @@ def bessel_i1(x, name='bessel_i1'): Equivalent to scipy.special.i1 @end_compatibility """ - with ops.name_scope(name, [x]): + with ops.name_scope(name, 'bessel_i1', [x]): return math_ops.exp(math_ops.abs(x)) * math_ops.bessel_i1e(x) @@ -201,8 +201,8 @@ def einsum(equation, *inputs, **kwargs): indices in its subscript, or - the input shapes are inconsistent along a particular axis. """ - equation = equation.replace(" ", "") - + equation = equation.replace(' ', '') + name = kwargs.pop('name', None) if kwargs: raise TypeError('invalid keyword arguments for this function: ' + ', '.join( diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py index b7e164f149a9cca336fee061ae2cc3a464ca6132..9bc4098d5b63c3e8ee4f9c14332e65b3d2875d8b 100644 --- a/tensorflow/python/ops/special_math_ops_test.py +++ b/tensorflow/python/ops/special_math_ops_test.py @@ -25,24 +25,25 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import special_math_ops from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging - class LBetaTest(test.TestCase): + @test_util.run_in_graph_and_eager_modes def test_one_dimensional_arg(self): # Should evaluate to 1 and 1/2. x_one = [1, 1.] x_one_half = [2, 1.] with self.test_session(use_gpu=True): - self.assertAllClose(1, math_ops.exp(special_math_ops.lbeta(x_one)).eval()) - self.assertAllClose(0.5, - math_ops.exp( - special_math_ops.lbeta(x_one_half)).eval()) + self.assertAllClose( + 1, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one)))) + self.assertAllClose( + 0.5, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one_half)))) self.assertEqual([], special_math_ops.lbeta(x_one).get_shape()) def test_one_dimensional_arg_dynamic(self): @@ -53,7 +54,8 @@ class LBetaTest(test.TestCase): ph = array_ops.placeholder(dtypes.float32) beta_ph = math_ops.exp(special_math_ops.lbeta(ph)) self.assertAllClose(1, beta_ph.eval(feed_dict={ph: x_one})) - self.assertAllClose(0.5, beta_ph.eval(feed_dict={ph: x_one_half})) + self.assertAllClose(0.5, + beta_ph.eval(feed_dict={ph: x_one_half})) def test_four_dimensional_arg_with_partial_shape_dynamic(self): x_ = np.ones((3, 2, 3, 4)) @@ -66,15 +68,17 @@ class LBetaTest(test.TestCase): with self.test_session(use_gpu=True): x_ph = array_ops.placeholder(dtypes.float32, [3, 2, 3, None]) beta_ph = math_ops.exp(special_math_ops.lbeta(x_ph)) - self.assertAllClose(expected_beta_x, beta_ph.eval(feed_dict={x_ph: x_})) + self.assertAllClose(expected_beta_x, + beta_ph.eval(feed_dict={x_ph: x_})) + @test_util.run_in_graph_and_eager_modes def test_two_dimensional_arg(self): # Should evaluate to 1/2. x_one_half = [[2, 1.], [2, 1.]] with self.test_session(use_gpu=True): - self.assertAllClose([0.5, 0.5], - math_ops.exp( - special_math_ops.lbeta(x_one_half)).eval()) + self.assertAllClose( + [0.5, 0.5], + self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one_half)))) self.assertEqual((2,), special_math_ops.lbeta(x_one_half).get_shape()) def test_two_dimensional_arg_dynamic(self): @@ -83,50 +87,59 @@ class LBetaTest(test.TestCase): with self.test_session(use_gpu=True): ph = array_ops.placeholder(dtypes.float32) beta_ph = math_ops.exp(special_math_ops.lbeta(ph)) - self.assertAllClose([0.5, 0.5], beta_ph.eval(feed_dict={ph: x_one_half})) + self.assertAllClose([0.5, 0.5], + beta_ph.eval(feed_dict={ph: x_one_half})) + @test_util.run_in_graph_and_eager_modes def test_two_dimensional_proper_shape(self): # Should evaluate to 1/2. x_one_half = [[2, 1.], [2, 1.]] with self.test_session(use_gpu=True): - self.assertAllClose([0.5, 0.5], - math_ops.exp( - special_math_ops.lbeta(x_one_half)).eval()) + self.assertAllClose( + [0.5, 0.5], + self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one_half)))) self.assertEqual( (2,), - array_ops.shape(special_math_ops.lbeta(x_one_half)).eval()) + self.evaluate(array_ops.shape(special_math_ops.lbeta(x_one_half)))) self.assertEqual( tensor_shape.TensorShape([2]), special_math_ops.lbeta(x_one_half).get_shape()) + @test_util.run_in_graph_and_eager_modes def test_complicated_shape(self): with self.test_session(use_gpu=True): x = ops.convert_to_tensor(np.random.rand(3, 2, 2)) - self.assertAllEqual((3, 2), - array_ops.shape(special_math_ops.lbeta(x)).eval()) + self.assertAllEqual( + (3, 2), self.evaluate(array_ops.shape(special_math_ops.lbeta(x)))) self.assertEqual( tensor_shape.TensorShape([3, 2]), special_math_ops.lbeta(x).get_shape()) + @test_util.run_in_graph_and_eager_modes def test_length_1_last_dimension_results_in_one(self): # If there is only one coefficient, the formula still works, and we get one # as the answer, always. x_a = [5.5] x_b = [0.1] with self.test_session(use_gpu=True): - self.assertAllClose(1, math_ops.exp(special_math_ops.lbeta(x_a)).eval()) - self.assertAllClose(1, math_ops.exp(special_math_ops.lbeta(x_b)).eval()) + self.assertAllClose( + 1, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_a)))) + self.assertAllClose( + 1, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_b)))) self.assertEqual((), special_math_ops.lbeta(x_a).get_shape()) + @test_util.run_in_graph_and_eager_modes def test_empty_rank1_returns_negative_infinity(self): with self.test_session(use_gpu=True): x = constant_op.constant([], shape=[0]) lbeta_x = special_math_ops.lbeta(x) expected_result = constant_op.constant(-np.inf, shape=()) - self.assertAllEqual(expected_result.eval(), lbeta_x.eval()) + self.assertAllEqual(self.evaluate(expected_result), + self.evaluate(lbeta_x)) self.assertEqual(expected_result.get_shape(), lbeta_x.get_shape()) + @test_util.run_in_graph_and_eager_modes def test_empty_rank2_with_zero_last_dim_returns_negative_infinity(self): with self.test_session(use_gpu=True): event_size = 0 @@ -135,9 +148,11 @@ class LBetaTest(test.TestCase): lbeta_x = special_math_ops.lbeta(x) expected_result = constant_op.constant(-np.inf, shape=[batch_size]) - self.assertAllEqual(expected_result.eval(), lbeta_x.eval()) + self.assertAllEqual(self.evaluate(expected_result), + self.evaluate(lbeta_x)) self.assertEqual(expected_result.get_shape(), lbeta_x.get_shape()) + @test_util.run_in_graph_and_eager_modes def test_empty_rank2_with_zero_batch_dim_returns_empty(self): with self.test_session(use_gpu=True): batch_size = 0 @@ -147,12 +162,14 @@ class LBetaTest(test.TestCase): expected_result = constant_op.constant([], shape=[batch_size]) - self.assertAllEqual(expected_result.eval(), lbeta_x.eval()) + self.assertAllEqual(self.evaluate(expected_result), + self.evaluate(lbeta_x)) self.assertEqual(expected_result.get_shape(), lbeta_x.get_shape()) class BesselTest(test.TestCase): + @test_util.run_in_graph_and_eager_modes def test_bessel_i0(self): x_single = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32) x_double = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64) @@ -165,6 +182,7 @@ class BesselTest(test.TestCase): except ImportError as e: tf_logging.warn('Cannot test special functions: %s' % str(e)) + @test_util.run_in_graph_and_eager_modes def test_bessel_i1(self): x_single = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32) x_double = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64) @@ -316,7 +334,7 @@ class EinsumTest(test.TestCase): output_tensor = special_math_ops.einsum(axes, *input_tensors) with self.test_session(use_gpu=True): - output_value = output_tensor.eval() + output_value = self.evaluate(output_tensor) correct_value = np.einsum(axes, *input_vals) diff --git a/tensorflow/python/ops/spectral_ops.py b/tensorflow/python/ops/spectral_ops.py index 28054f50ef3b1227f12376b4b3700a7618270d65..293aace7282eb0f8dde9da75b0d353a560c0ecb9 100644 --- a/tensorflow/python/ops/spectral_ops.py +++ b/tensorflow/python/ops/spectral_ops.py @@ -167,8 +167,8 @@ def _validate_dct_arguments(dct_type, n, axis, norm): raise NotImplementedError("The DCT length argument is not implemented.") if axis != -1: raise NotImplementedError("axis must be -1. Got: %s" % axis) - if dct_type != 2: - raise ValueError("Only the Type II DCT is supported.") + if dct_type not in (2, 3): + raise ValueError("Only Types II and III (I)DCT are supported.") if norm not in (None, "ortho"): raise ValueError( "Unknown normalization. Expected None or 'ortho', got: %s" % norm) @@ -179,18 +179,20 @@ def _validate_dct_arguments(dct_type, n, axis, norm): def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin """Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`. - Currently only Type II is supported. Implemented using a length `2N` padded - @{tf.spectral.rfft}, as described here: https://dsp.stackexchange.com/a/10606 + Currently only Types II and III are supported. Type II is implemented using a + length `2N` padded @{tf.spectral.rfft}, as described here: + https://dsp.stackexchange.com/a/10606. Type III is a fairly straightforward + inverse of Type II (i.e. using a length `2N` padded @{tf.spectral.irfft}). @compatibility(scipy) - Equivalent to scipy.fftpack.dct for the Type-II DCT. + Equivalent to scipy.fftpack.dct for Type-II and Type-III DCT. https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html @end_compatibility Args: input: A `[..., samples]` `float32` `Tensor` containing the signals to take the DCT of. - type: The DCT type to perform. Must be 2. + type: The DCT type to perform. Must be 2 or 3. n: For future expansion. The length of the transform. Must be `None`. axis: For future expansion. The axis to compute the DCT along. Must be `-1`. norm: The normalization to apply. `None` for no normalization or `'ortho'` @@ -201,8 +203,8 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl A `[..., samples]` `float32` `Tensor` containing the DCT of `input`. Raises: - ValueError: If `type` is not `2`, `n` is not `None, `axis` is not `-1`, or - `norm` is not `None` or `'ortho'`. + ValueError: If `type` is not `2` or `3`, `n` is not `None, `axis` is not + `-1`, or `norm` is not `None` or `'ortho'`. [dct]: https://en.wikipedia.org/wiki/Discrete_cosine_transform """ @@ -214,22 +216,91 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl axis_dim = input.shape[-1].value or _array_ops.shape(input)[-1] axis_dim_float = _math_ops.to_float(axis_dim) - scale = 2.0 * _math_ops.exp(_math_ops.complex( - 0.0, -_math.pi * _math_ops.range(axis_dim_float) / - (2.0 * axis_dim_float))) - - # TODO(rjryan): Benchmark performance and memory usage of the various - # approaches to computing a DCT via the RFFT. - dct2 = _math_ops.real( - rfft(input, fft_length=[2 * axis_dim])[..., :axis_dim] * scale) - - if norm == "ortho": - n1 = 0.5 * _math_ops.rsqrt(axis_dim_float) - n2 = n1 * _math_ops.sqrt(2.0) - # Use tf.pad to make a vector of [n1, n2, n2, n2, ...]. - weights = _array_ops.pad( - _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]], - constant_values=n2) - dct2 *= weights - - return dct2 + if type == 2: + scale = 2.0 * _math_ops.exp( + _math_ops.complex( + 0.0, -_math_ops.range(axis_dim_float) * _math.pi * 0.5 / + axis_dim_float)) + + # TODO(rjryan): Benchmark performance and memory usage of the various + # approaches to computing a DCT via the RFFT. + dct2 = _math_ops.real( + rfft(input, fft_length=[2 * axis_dim])[..., :axis_dim] * scale) + + if norm == "ortho": + n1 = 0.5 * _math_ops.rsqrt(axis_dim_float) + n2 = n1 * _math_ops.sqrt(2.0) + # Use tf.pad to make a vector of [n1, n2, n2, n2, ...]. + weights = _array_ops.pad( + _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]], + constant_values=n2) + dct2 *= weights + + return dct2 + + elif type == 3: + if norm == "ortho": + n1 = _math_ops.sqrt(axis_dim_float) + n2 = n1 * _math_ops.sqrt(0.5) + # Use tf.pad to make a vector of [n1, n2, n2, n2, ...]. + weights = _array_ops.pad( + _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]], + constant_values=n2) + input *= weights + else: + input *= axis_dim_float + scale = 2.0 * _math_ops.exp( + _math_ops.complex( + 0.0, + _math_ops.range(axis_dim_float) * _math.pi * 0.5 / + axis_dim_float)) + dct3 = _math_ops.real( + irfft( + scale * _math_ops.complex(input, 0.0), + fft_length=[2 * axis_dim]))[..., :axis_dim] + + return dct3 + + +# TODO(rjryan): Implement `type`, `n` and `axis` parameters. +@tf_export("spectral.idct") +def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin + """Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`. + + Currently only Types II and III are supported. Type III is the inverse of + Type II, and vice versa. + + Note that you must re-normalize by 1/(2n) to obtain an inverse if `norm` is + not `'ortho'`. That is: + `signal == idct(dct(signal)) * 0.5 / signal.shape[-1]`. + When `norm='ortho'`, we have: + `signal == idct(dct(signal, norm='ortho'), norm='ortho')`. + + @compatibility(scipy) + Equivalent to scipy.fftpack.idct for Type-II and Type-III DCT. + https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.idct.html + @end_compatibility + + Args: + input: A `[..., samples]` `float32` `Tensor` containing the signals to take + the DCT of. + type: The IDCT type to perform. Must be 2 or 3. + n: For future expansion. The length of the transform. Must be `None`. + axis: For future expansion. The axis to compute the DCT along. Must be `-1`. + norm: The normalization to apply. `None` for no normalization or `'ortho'` + for orthonormal normalization. + name: An optional name for the operation. + + Returns: + A `[..., samples]` `float32` `Tensor` containing the IDCT of `input`. + + Raises: + ValueError: If `type` is not `2` or `3`, `n` is not `None, `axis` is not + `-1`, or `norm` is not `None` or `'ortho'`. + + [idct]: + https://en.wikipedia.org/wiki/Discrete_cosine_transform#Inverse_transforms + """ + _validate_dct_arguments(type, n, axis, norm) + inverse_type = {2: 3, 3: 2}[type] + return dct(input, type=inverse_type, n=n, axis=axis, norm=norm, name=name) diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index a2d24711e2291bafcf5736c6206ceb09ac210453..d0e5f700254fa5273cb707e59ac0d141fdc13627 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import cudnn_rnn_grad from tensorflow.python.ops import data_flow_grad from tensorflow.python.ops import manip_grad from tensorflow.python.ops import math_grad +from tensorflow.python.ops import random_grad from tensorflow.python.ops import sparse_grad from tensorflow.python.ops import spectral_grad from tensorflow.python.ops import state_grad diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index 08b7cda73bdc739912ec58f161ec7113aeffd9e8..8cb6a0537e928effbcf4c475bcc4e974182da2a7 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -394,7 +394,7 @@ def scatter_add(ref, indices, updates, use_locking=False, name=None): A tensor of indices into the first dimension of `ref`. updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated values to store in `ref`. - use_locking: An optional `bool`. Defaults to `True`. + use_locking: An optional `bool`. Defaults to `False`. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. name: A name for the operation (optional). @@ -458,7 +458,7 @@ def scatter_nd_add(ref, indices, updates, use_locking=False, name=None): A tensor of indices into ref. updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated values to add to ref. - use_locking: An optional `bool`. Defaults to `True`. + use_locking: An optional `bool`. Defaults to `False`. An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py index b80f84eb7cde264c5a7c83eafacc344adb50b80a..00150fe68820da711c76f642baced45163a8727c 100644 --- a/tensorflow/python/ops/summary_ops_v2.py +++ b/tensorflow/python/ops/summary_ops_v2.py @@ -306,10 +306,11 @@ def create_db_writer(db_uri, def _make_summary_writer(name, factory, **kwargs): resource = gen_summary_ops.summary_writer(shared_name=name) init_op_fn = lambda: factory(resource, **kwargs) - # TODO(apassos): Consider doing this instead. - # if not context.executing_eagerly(): - # ops.get_default_session().run(init_op) - ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, init_op_fn()) + init_op = init_op_fn() + if not context.executing_eagerly(): + # TODO(apassos): Consider doing this instead. + # ops.get_default_session().run(init_op) + ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, init_op) return SummaryWriter(resource, init_op_fn) @@ -380,7 +381,8 @@ def summary_writer_function(name, tensor, function, family=None): with ops.device("cpu:0"): op = smart_cond.smart_cond( should_record_summaries(), record, _nothing, name="") - ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, op) # pylint: disable=protected-access + if not context.executing_eagerly(): + ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, op) # pylint: disable=protected-access return op diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 4be9f5eb6864015cd9c3f6f3526285ebbdc180f9..d3172838a4e25bfd8ca10e15991aeba47ff44192 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -1093,39 +1093,40 @@ class Variable(checkpointable.CheckpointableBase): def __imul__(self, other): logging.log_first_n( logging.WARN, - "Variable *= will be deprecated. Use variable.assign_mul" - " if you want assignment to the variable value or 'x = x * y'" + "Variable *= will be deprecated. Use `var.assign(var * other)`" + " if you want assignment to the variable value or `x = x * y`" " if you want a new python Tensor object.", 1) return self * other def __idiv__(self, other): logging.log_first_n( logging.WARN, - "Variable /= will be deprecated. Use variable.assign_div" - " if you want assignment to the variable value or 'x = x / y'" + "Variable /= will be deprecated. Use `var.assign(var / other)`" + " if you want assignment to the variable value or `x = x / y`" " if you want a new python Tensor object.", 1) return self / other def __itruediv__(self, other): logging.log_first_n( logging.WARN, - "Variable /= will be deprecated. Use variable.assign_div" - " if you want assignment to the variable value or 'x = x / y'" + "Variable /= will be deprecated. Use `var.assign(var / other)`" + " if you want assignment to the variable value or `x = x / y`" " if you want a new python Tensor object.", 1) return self / other def __irealdiv__(self, other): logging.log_first_n( logging.WARN, - "Variable /= will be deprecated. Use variable.assign_div" - " if you want assignment to the variable value or 'x = x / y'" + "Variable /= will be deprecated. Use `var.assign(var / other)`" + " if you want assignment to the variable value or `x = x / y`" " if you want a new python Tensor object.", 1) return self / other def __ipow__(self, other): logging.log_first_n( logging.WARN, - "Variable **= will be deprecated. Use 'x = x ** y'" + "Variable **= will be deprecated. Use `var.assign(var ** other)`" + " if you want assignment to the variable value or `x = x ** y`" " if you want a new python Tensor object.", 1) return self ** other diff --git a/tensorflow/python/profiler/model_analyzer_test.py b/tensorflow/python/profiler/model_analyzer_test.py index 9e49188c1ef353d345c97ea0295aa1a68283605e..f9891f3b1e2e94f61329babd1409e3efacc7f5b3 100644 --- a/tensorflow/python/profiler/model_analyzer_test.py +++ b/tensorflow/python/profiler/model_analyzer_test.py @@ -707,8 +707,10 @@ class PrintModelAnalysisTest(test.TestCase): a = array_ops.constant(np.ones((100, 100))) b = array_ops.constant(np.ones((100, 100))) c = a * b + config = config_pb2.ConfigProto() + config.graph_options.rewrite_options.min_graph_nodes = -1 - with session.Session() as sess: + with session.Session(config=config) as sess: run_options = config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE) run_metadata = config_pb2.RunMetadata() diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 500dc30cc30f757965791e504bc79718bb7f7bd7..5d7535cf34f7396b7ff6aebd3984046e51c98347 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -59,6 +59,7 @@ limitations under the License. %rename("%s") TFE_ContextOptionsSetConfig; %rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy; %rename("%s") TFE_ContextOptionsSetAsync; +%rename("%s") TFE_ContextOptionsSetServerDef; %rename("%s") TFE_DeleteContextOptions; %rename("%s") TFE_Py_TensorShapeSlice; %rename("%s") TFE_Py_TensorShapeOnDevice; diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 81786fbf435ffebba6217c0a03f06494195afc3c..076f2d8760fe00035ef5830a02d22e82c54dd768 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -87,6 +87,30 @@ py_library( "//tensorflow/python:platform", "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python:variables", + ], +) + +py_test( + name = "loader_test", + size = "small", + srcs = ["loader_test.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:private"], + deps = [ + ":builder", + ":loader", + ":signature_def_utils", + ":utils", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:lib", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", ], ) diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py index d1bd8d47aee94fd913b807860eff7fa94bb469e5..e5f649fdabb5cc2600a6fdd0e5ed9950d6bb23c2 100644 --- a/tensorflow/python/saved_model/loader_impl.py +++ b/tensorflow/python/saved_model/loader_impl.py @@ -28,6 +28,7 @@ from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import saved_model_pb2 from tensorflow.python.framework import ops from tensorflow.python.lib.io import file_io +from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging from tensorflow.python.saved_model import constants from tensorflow.python.training import saver as tf_saver @@ -207,11 +208,56 @@ def load(sess, tags, export_dir, import_scope=None, **saver_kwargs): Raises: RuntimeError: MetaGraphDef associated with the tags cannot be found. """ - with sess.graph.as_default(): - # Build the SavedModel protocol buffer and find requested meta graph def. - saved_model = _parse_saved_model(export_dir) + loader = SavedModelLoader(export_dir) + return loader.load(sess, tags, import_scope, **saver_kwargs) + + +class SavedModelLoader(object): + """Load graphs and restore variable values from a `SavedModel`.""" + + def __init__(self, export_dir): + """Creates a `SavedModelLoader`. + + Args: + export_dir: Directory in which the SavedModel protocol buffer and + variables to be loaded are located. + """ + self._export_dir = export_dir + self._variables_path = os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes(constants.VARIABLES_DIRECTORY), + compat.as_bytes(constants.VARIABLES_FILENAME)) + self._saved_model = _parse_saved_model(export_dir) + + @property + def export_dir(self): + """Directory containing the SavedModel.""" + return self._export_dir + + @property + def variables_path(self): + """Path to variable checkpoint files.""" + return self._variables_path + + @property + def saved_model(self): + """SavedModel object parsed from the export directory.""" + return self._saved_model + + def get_meta_graph_def_from_tags(self, tags): + """Return MetaGraphDef with the exact specified tags. + + Args: + tags: A list or set of string tags that identify the MetaGraphDef. + + Returns: + MetaGraphDef with the same tags. + + Raises: + RuntimeError: if no metagraphs were found with the associated tags. + """ found_match = False - for meta_graph_def in saved_model.meta_graphs: + for meta_graph_def in self._saved_model.meta_graphs: if set(meta_graph_def.meta_info_def.tags) == set(tags): meta_graph_def_to_load = meta_graph_def found_match = True @@ -223,32 +269,100 @@ def load(sess, tags, export_dir, import_scope=None, **saver_kwargs): " could not be found in SavedModel. To inspect available tag-sets in" " the SavedModel, please use the SavedModel CLI: `saved_model_cli`" ) + return meta_graph_def_to_load - # Build a saver by importing the meta graph def to load. - saver = tf_saver.import_meta_graph( - meta_graph_def_to_load, import_scope=import_scope, **saver_kwargs) - - if saver: - # Build the checkpoint path where the variables are located. - variables_path = os.path.join( - compat.as_bytes(export_dir), - compat.as_bytes(constants.VARIABLES_DIRECTORY), - compat.as_bytes(constants.VARIABLES_FILENAME)) - - # Restore the variables using the built saver in the provided session. - saver.restore(sess, variables_path) - else: - tf_logging.info("The specified SavedModel has no variables; no " - "checkpoints were restored.") - - # Get asset tensors, if any. - asset_tensors_dictionary = _get_asset_tensors( - export_dir, meta_graph_def_to_load, import_scope=import_scope) - - main_op_tensor = ( - _get_main_op_tensor(meta_graph_def_to_load) or - (_get_legacy_init_op_tensor(meta_graph_def_to_load))) - if main_op_tensor is not None: - sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary) + def load_graph(self, graph, tags, import_scope=None, **saver_kwargs): + """Load ops and nodes from SavedModel MetaGraph into graph. - return meta_graph_def_to_load + Args: + graph: tf.Graph object. + tags: a set of string tags identifying a MetaGraphDef. + import_scope: Optional `string` -- if specified, prepend this string + followed by '/' to all loaded tensor names. This scope is applied to + tensor instances loaded into the passed session, but it is *not* written + through to the static `MetaGraphDef` protocol buffer that is returned. + **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph. + + Returns: + Saver defined by the MetaGraph, which can be used to restore the variable + values. + """ + meta_graph_def = self.get_meta_graph_def_from_tags(tags) + with graph.as_default(): + return tf_saver.import_meta_graph( + meta_graph_def, import_scope=import_scope, **saver_kwargs) + + def restore_variables(self, sess, saver, import_scope=None): + """Restore SavedModel variable values into the session. + + Args: + sess: tf.Session to restore variable values. + saver: a tf.train.Saver object. Can be None if there are no variables in + graph. This may be the saver returned by the load_graph() function, or a + default `tf.train.Saver()`. + import_scope: Optional `string` -- if specified, prepend this string + followed by '/' to all loaded tensor names. This scope is applied to + tensor instances loaded into the passed session, but it is *not* written + through to the static `MetaGraphDef` protocol buffer that is returned. + + Raises: + ValueError: if no saver was passed to the saver argument, and there are + variables in the graph. + """ + with sess.graph.as_default(): + if (saver is None and + not variables._all_saveable_objects(scope=import_scope)): # pylint: disable=protected-access + tf_logging.info("The specified SavedModel has no variables; no " + "checkpoints were restored.") + elif isinstance(saver, tf_saver.Saver): + saver.restore(sess, self._variables_path) + else: + raise ValueError( + "No tf.train.Saver object was passed to the function " + "SavedModelLoader.restore_variables. Since there are variables in " + "the graph, a saver is required.") + + def run_init_ops(self, sess, tags, import_scope=None): + """Run initialization ops defined in the `MetaGraphDef`. + + Args: + sess: tf.Session to restore variable values. + tags: a set of string tags identifying a MetaGraphDef. + import_scope: Optional `string` -- if specified, prepend this string + followed by '/' to all loaded tensor names. This scope is applied to + tensor instances loaded into the passed session, but it is *not* written + through to the static `MetaGraphDef` protocol buffer that is returned. + """ + meta_graph_def = self.get_meta_graph_def_from_tags(tags) + with sess.graph.as_default(): + # Get asset tensors, if any. + asset_tensors_dictionary = _get_asset_tensors( + self._export_dir, meta_graph_def, import_scope=import_scope) + + main_op_tensor = ( + _get_main_op_tensor(meta_graph_def) or + (_get_legacy_init_op_tensor(meta_graph_def))) + if main_op_tensor is not None: + sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary) + + def load(self, sess, tags, import_scope=None, **saver_kwargs): + """Load the MetaGraphDef graph and restore variable values into the session. + + Args: + sess: tf.Session to restore variable values. + tags: a set of string tags identifying a MetaGraphDef. + import_scope: Optional `string` -- if specified, prepend this string + followed by '/' to all loaded tensor names. This scope is applied to + tensor instances loaded into the passed session, but it is *not* written + through to the static `MetaGraphDef` protocol buffer that is returned. + **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph. + + Returns: + `MetagraphDef` proto of the graph that was loaded. + """ + with sess.graph.as_default(): + saver = self.load_graph(sess.graph, tags, import_scope, + **saver_kwargs) + self.restore_variables(sess, saver, import_scope) + self.run_init_ops(sess, tags, import_scope) + return self.get_meta_graph_def_from_tags(tags) diff --git a/tensorflow/python/saved_model/loader_test.py b/tensorflow/python/saved_model/loader_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ce18859f6b9e4c141c4b27f3643c8d4004eb56f6 --- /dev/null +++ b/tensorflow/python/saved_model/loader_test.py @@ -0,0 +1,217 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for SavedModelLoader class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.client import session +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.lib.io import file_io +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.saved_model import builder as saved_model_builder +from tensorflow.python.saved_model import loader_impl +from tensorflow.python.saved_model import signature_def_utils +from tensorflow.python.saved_model import utils +from tensorflow.python.training import saver as tf_saver + + +def _get_export_dir(label): + return os.path.join(test.get_temp_dir(), label) + +SIMPLE_ADD_SAVED_MODEL = _get_export_dir("simple_add_saved_model") +SAVED_MODEL_WITH_MAIN_OP = _get_export_dir("saved_model_with_main_op") + + +class SavedModelLoaderTest(test.TestCase): + + def setUp(self): + """Write test SavedModels to a temp directory.""" + with session.Session(graph=ops.Graph()) as sess: + x = variables.Variable(5, name="x") + y = variables.Variable(11, name="y") + z = x + y + sess.run(variables.global_variables_initializer()) + + foo_sig_def = signature_def_utils.build_signature_def( + {"foo_input": utils.build_tensor_info(x)}, + {"foo_output": utils.build_tensor_info(z)}) + bar_sig_def = signature_def_utils.build_signature_def( + {"bar_x": utils.build_tensor_info(x), + "bar_y": utils.build_tensor_info(y)}, + {"bar_z": utils.build_tensor_info(z)}) + + builder = saved_model_builder.SavedModelBuilder(SIMPLE_ADD_SAVED_MODEL) + builder.add_meta_graph_and_variables( + sess, ["foo_graph"], {"foo": foo_sig_def, "bar": bar_sig_def}) + builder.save() + + # Write SavedModel with a main_op + assign_op = control_flow_ops.group(state_ops.assign(y, 7)) + + builder = saved_model_builder.SavedModelBuilder(SAVED_MODEL_WITH_MAIN_OP) + builder.add_meta_graph_and_variables( + sess, ["foo_graph"], {"foo": foo_sig_def, "bar": bar_sig_def}, + main_op=assign_op) + builder.save() + + def tearDown(self): + file_io.delete_recursively(test.get_temp_dir()) + + def test_load_function(self): + loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["foo_graph"]) + self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) + self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval()) + + loader2 = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) + with self.test_session(graph=ops.Graph()) as sess: + loader2.load(sess, ["foo_graph"]) + self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) + self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval()) + + def test_load_graph(self): + loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) + graph = ops.Graph() + loader.load_graph(graph, ["foo_graph"]) + + x = graph.get_tensor_by_name("x:0") + y = graph.get_tensor_by_name("y:0") + + with self.assertRaises(KeyError): + graph.get_tensor_by_name("z:0") + + with self.test_session(graph=graph) as sess: + # Check that x and y are not initialized + with self.assertRaises(errors.FailedPreconditionError): + sess.run(x) + with self.assertRaises(errors.FailedPreconditionError): + sess.run(y) + + def test_load_with_import_scope(self): + loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) + with self.test_session(graph=ops.Graph()) as sess: + saver = loader.load_graph(sess.graph, ["foo_graph"], import_scope="baz") + + # The default saver should not work when the import scope is set. + with self.assertRaises(errors.NotFoundError): + loader.restore_variables(sess, tf_saver.Saver()) + + loader.restore_variables(sess, saver) + loader.run_init_ops(sess, ["foo_graph"]) + + self.assertEqual(5, sess.graph.get_tensor_by_name("baz/x:0").eval()) + self.assertEqual(7, sess.graph.get_tensor_by_name("baz/y:0").eval()) + + # Test combined load function. + loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["foo_graph"], import_scope="baa") + self.assertEqual(5, sess.graph.get_tensor_by_name("baa/x:0").eval()) + self.assertEqual(7, sess.graph.get_tensor_by_name("baa/y:0").eval()) + + def test_restore_variables(self): + loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) + with self.test_session(graph=ops.Graph()) as sess: + x = variables.Variable(0, name="x") + y = variables.Variable(0, name="y") + z = x * y + + sess.run(variables.global_variables_initializer()) + + # There are variables to restore, so a saver must be created. + with self.assertRaises(ValueError): + loader.restore_variables(sess, None) + + loader.restore_variables(sess, tf_saver.Saver()) + self.assertEqual(55, z.eval()) + + def test_run_init_op(self): + loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) + graph = ops.Graph() + saver = loader.load_graph(graph, ["foo_graph"]) + with self.test_session(graph=graph) as sess: + loader.restore_variables(sess, saver) + self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) + self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval()) + + loader.run_init_ops(sess, ["foo_graph"]) + self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) + self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval()) + + def test_parse_saved_model(self): + loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) + meta_graph = loader.get_meta_graph_def_from_tags(["foo_graph"]) + self.assertIsNotNone(meta_graph) + self.assertIn("foo", meta_graph.signature_def) + self.assertIn("bar", meta_graph.signature_def) + + def test_load_invalid_meta_graph(self): + loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) + with self.assertRaises(RuntimeError): + loader.get_meta_graph_def_from_tags([]) + with self.assertRaises(RuntimeError): + loader.get_meta_graph_def_from_tags([""]) + with self.assertRaises(RuntimeError): + loader.get_meta_graph_def_from_tags(["not_a_graph"]) + + def test_load_saved_model_with_no_variables(self): + """Test that SavedModel runs saver when there appear to be no variables. + + When no variables are detected, this may mean that the variables were saved + to different collections, or the collections weren't saved to the + SavedModel. If the SavedModel MetaGraphDef contains a saver, it should still + run in either of these cases. + """ + path = _get_export_dir("no_variable_saved_model") + with session.Session(graph=ops.Graph()) as sess: + x = variables.Variable(5, name="x", collections=["not_global_variable"]) + y = variables.Variable(11, name="y", collections=["not_global_variable"]) + self.assertFalse(variables._all_saveable_objects()) + z = x + y + sess.run(variables.variables_initializer([x, y])) + + foo_sig_def = signature_def_utils.build_signature_def( + {"foo_input": utils.build_tensor_info(x)}, + {"foo_output": utils.build_tensor_info(z)}) + + builder = saved_model_builder.SavedModelBuilder(path) + builder.add_meta_graph_and_variables( + sess, ["foo_graph"], {"foo": foo_sig_def}, + saver=tf_saver.Saver([x, y])) + builder.save() + + loader = loader_impl.SavedModelLoader(path) + with self.test_session(graph=ops.Graph()) as sess: + saver = loader.load_graph(sess.graph, ["foo_graph"]) + self.assertFalse(variables._all_saveable_objects()) + self.assertIsNotNone(saver) + + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["foo_graph"]) + self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) + self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 5b9d25d449d43d8420e0f30fa8b907d41171d5e5..38fed5335ef39e9832c8b47e3c872ada453aa645 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -15,7 +15,7 @@ """Command-line interface to inspect and execute a graph in a SavedModel. For detailed usages and examples, please refer to: -https://www.tensorflow.org/programmers_guide/saved_model_cli +https://www.tensorflow.org/guide/saved_model_cli """ @@ -720,7 +720,7 @@ def create_parser(): '\'input4_key=[{"id":[26],"weights":[0.5, 0.5]}]\' \\\n' ' --outdir=/out\n\n' 'For more information about input file format, please see:\n' - 'https://www.tensorflow.org/programmers_guide/saved_model_cli\n') + 'https://www.tensorflow.org/guide/saved_model_cli\n') parser_run = subparsers.add_parser( 'run', description=run_msg, formatter_class=argparse.RawTextHelpFormatter) parser_run.add_argument( diff --git a/tensorflow/python/training/checkpointable/BUILD b/tensorflow/python/training/checkpointable/BUILD index 87ba4dc91c89e03ac5f2a93bedca81878f5254a6..54f359489e97471247e57187c1f3f0d7332cfc6f 100644 --- a/tensorflow/python/training/checkpointable/BUILD +++ b/tensorflow/python/training/checkpointable/BUILD @@ -42,21 +42,38 @@ py_test( ) py_library( - name = "data_structures_base", - srcs = ["data_structures_base.py"], + name = "tracking", + srcs = ["tracking.py"], srcs_version = "PY2AND3", deps = [ ":base", ], ) +py_test( + name = "tracking_test", + srcs = ["tracking_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":base", + ":tracking", + "//tensorflow/python:client_testlib", + ], +) + +py_library( + name = "layer_utils", + srcs = ["layer_utils.py"], + srcs_version = "PY2AND3", +) + py_library( name = "data_structures", srcs = ["data_structures.py"], srcs_version = "PY2AND3", deps = [ ":base", - ":data_structures_base", + ":layer_utils", ], ) @@ -83,6 +100,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":base", + ":tracking", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py index cfe7259e1b6d9932fff9e78049fa85554f022076..99c8098eca236549ec5cff10ad6e79badb996a7d 100644 --- a/tensorflow/python/training/checkpointable/base.py +++ b/tensorflow/python/training/checkpointable/base.py @@ -758,61 +758,3 @@ class NoDependency(object): def __init__(self, value): self.value = value - - -class NotCheckpointable(object): - """Marks instances of child classes as unsaveable using an object-based API. - - Useful for marking objects which would otherwise look checkpointable because - of inheritance (e.g. through `Layer`) as not checkpointable. Inheriting from - `NotCheckpointable` does not prevent an object from being assigned to any - attributes, but will throw an error on save/restore. - """ - pass - - -class Checkpointable(CheckpointableBase): - """Manages dependencies on other objects. - - `Checkpointable` objects may have dependencies: other `Checkpointable` objects - which should be saved if the object declaring the dependency is saved. A - correctly saveable program has a dependency graph such that if changing a - global variable affects an object (e.g. changes the behavior of any of its - methods) then there is a chain of dependencies from the influenced object to - the variable. - - Dependency edges have names, and are created implicitly when a - `Checkpointable` object is assigned to an attribute of another - `Checkpointable` object. For example: - - ``` - obj = Checkpointable() - obj.v = ResourceVariable(0.) - ``` - - The `Checkpointable` object `obj` now has a dependency named "v" on a - variable. - - `Checkpointable` objects may specify `Tensor`s to be saved and restored - directly (e.g. a `Variable` indicating how to save itself) rather than through - dependencies on other objects. See - `Checkpointable._gather_saveables_for_checkpoint` for details. - """ - - def __setattr__(self, name, value): - """Support self.foo = checkpointable syntax.""" - # Perform the attribute assignment, and potentially call other __setattr__ - # overrides such as that for tf.keras.Model. - no_dependency = isinstance(value, NoDependency) - if no_dependency: - value = value.value - super(Checkpointable, self).__setattr__(name, value) - if not no_dependency and isinstance(value, CheckpointableBase): - self._track_checkpointable( - value, name=name, - # Allow the user to switch the Checkpointable which is tracked by this - # name, since assigning a new variable to an attribute has - # historically been fine (e.g. Adam did this). - # TODO(allenl): Should this be a warning once Checkpointable save/load - # is usable? - overwrite=True) diff --git a/tensorflow/python/training/checkpointable/base_test.py b/tensorflow/python/training/checkpointable/base_test.py index 0a274cdfed5af83a69513e9b26bf427f284a4df7..950e9c5b535a8314e1068b772f48a14b572df691 100644 --- a/tensorflow/python/training/checkpointable/base_test.py +++ b/tensorflow/python/training/checkpointable/base_test.py @@ -17,33 +17,25 @@ from __future__ import division from __future__ import print_function from tensorflow.python.platform import test -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import base class InterfaceTests(test.TestCase): - def testMultipleAssignment(self): - root = checkpointable.Checkpointable() - root.leaf = checkpointable.Checkpointable() - root.leaf = root.leaf - duplicate_name_dep = checkpointable.Checkpointable() + def testOverwrite(self): + root = base.CheckpointableBase() + leaf = base.CheckpointableBase() + root._track_checkpointable(leaf, name="leaf") + (current_name, current_dependency), = root._checkpoint_dependencies + self.assertIs(leaf, current_dependency) + self.assertEqual("leaf", current_name) + duplicate_name_dep = base.CheckpointableBase() with self.assertRaises(ValueError): root._track_checkpointable(duplicate_name_dep, name="leaf") - # No error; we're overriding __setattr__, so we can't really stop people - # from doing this while maintaining backward compatibility. - root.leaf = duplicate_name_dep root._track_checkpointable(duplicate_name_dep, name="leaf", overwrite=True) - - def testNoDependency(self): - root = checkpointable.Checkpointable() - hasdep = checkpointable.Checkpointable() - root.hasdep = hasdep - nodep = checkpointable.Checkpointable() - root.nodep = checkpointable.NoDependency(nodep) - self.assertEqual(1, len(root._checkpoint_dependencies)) - self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep) - self.assertIs(root.hasdep, hasdep) - self.assertIs(root.nodep, nodep) + (current_name, current_dependency), = root._checkpoint_dependencies + self.assertIs(duplicate_name_dep, current_dependency) + self.assertEqual("leaf", current_name) if __name__ == "__main__": test.main() diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py index 69ed253fb2d874954ee7563cd8bb21add59a7318..c46585b4178cbd24dc7d2507b4b42aa823ea1305 100644 --- a/tensorflow/python/training/checkpointable/data_structures.py +++ b/tensorflow/python/training/checkpointable/data_structures.py @@ -21,11 +21,9 @@ import collections import six -from tensorflow.python.keras.engine import base_layer -from tensorflow.python.keras.utils import layer_utils from tensorflow.python.ops import variables from tensorflow.python.training.checkpointable import base as checkpointable_lib -from tensorflow.python.training.checkpointable import data_structures_base +from tensorflow.python.training.checkpointable import layer_utils # TODO(allenl): We could track regular Python data structures which get assigned @@ -36,8 +34,7 @@ from tensorflow.python.training.checkpointable import data_structures_base # user's updated structure, but would have no way to support restore-on-create # for those modifications). # TODO(allenl): A dictionary data structure would be good too. -class CheckpointableDataStructure( - data_structures_base.CheckpointableDataStructureBase): +class CheckpointableDataStructure(checkpointable_lib.CheckpointableBase): """Base class for data structures which contain checkpointable objects.""" def __init__(self): @@ -56,9 +53,8 @@ class CheckpointableDataStructure( ("Only checkpointable objects (such as Layers or Optimizers) may be " "stored in a List object. Got %s, which does not inherit from " "CheckpointableBase.") % (value,)) - if isinstance(value, ( - base_layer.Layer, - data_structures_base.CheckpointableDataStructureBase)): + if (isinstance(value, CheckpointableDataStructure) + or layer_utils.is_layer(value)): if value not in self._layers: self._layers.append(value) if hasattr(value, "_use_resource_variables"): diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py index b05b3a88002e31560ed6c2005fdd29f56c5227a3..ce5852dd6e1acbf36ef58a614148c12b9dbae039 100644 --- a/tensorflow/python/training/checkpointable/data_structures_test.py +++ b/tensorflow/python/training/checkpointable/data_structures_test.py @@ -66,7 +66,7 @@ class HasList(training.Model): class ListTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTracking(self): model = HasList() output = model(array_ops.ones([32, 2])) @@ -106,7 +106,7 @@ class ListTests(test.TestCase): model(model_input) self.assertEqual(0, len(model.updates)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLossesForwarded(self): model = HasList() model_input = array_ops.ones([32, 2]) @@ -190,7 +190,7 @@ class HasMapping(training.Model): class MappingTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTracking(self): model = HasMapping() output = model(array_ops.ones([32, 2])) diff --git a/tensorflow/python/training/checkpointable/layer_utils.py b/tensorflow/python/training/checkpointable/layer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fdcf963d326a8916ea694e678e5ccf0df30fe26a --- /dev/null +++ b/tensorflow/python/training/checkpointable/layer_utils.py @@ -0,0 +1,85 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities related to layer/model functionality.""" + +# TODO(b/110718070): Move these functions back to tensorflow/python/keras/utils +# once __init__ files no longer require all of tf.keras to be imported together. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +def is_layer(obj): + """Implicit check for Layer-like objects.""" + # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer). + return (hasattr(obj, "call") + and hasattr(obj, "build") + and hasattr(obj, "variables")) + + +def gather_trainable_weights(trainable, sub_layers, extra_variables): + """Lists the trainable weights for an object with sub-layers. + + Args: + trainable: Whether the object collecting the variables is trainable. + sub_layers: A flat list of Layer objects owned by this object, to collect + variables from. + extra_variables: Any extra variables to include. Their `.trainable` property + is used to categorize them. + + Returns: + A list of collected trainable weights/variables. + """ + if not trainable: + return [] + weights = [] + for layer in sub_layers: + weights += layer.trainable_weights + trainable_extra_variables = [ + v for v in extra_variables if v.trainable] + return weights + trainable_extra_variables + + +def gather_non_trainable_weights(trainable, sub_layers, extra_variables): + """Lists the non-trainable weights for an object with sub-layers. + + Args: + trainable: Whether the object collecting the variables is trainable. + sub_layers: A flat list of Layer objects owned by this object, to collect + variables from. + extra_variables: Any extra variables to include. Their `.trainable` property + is used to categorize them. + + Returns: + A list of collected non-trainable weights/variables. + """ + trainable_extra_variables = [] + non_trainable_extra_variables = [] + for v in extra_variables: + if v.trainable: + trainable_extra_variables.append(v) + else: + non_trainable_extra_variables.append(v) + weights = [] + for layer in sub_layers: + weights += layer.non_trainable_weights + if not trainable: + trainable_weights = [] + for layer in sub_layers: + trainable_weights += layer.trainable_weights + return (trainable_weights + trainable_extra_variables + + weights + non_trainable_extra_variables) + return weights + non_trainable_extra_variables diff --git a/tensorflow/python/training/checkpointable/tracking.py b/tensorflow/python/training/checkpointable/tracking.py new file mode 100644 index 0000000000000000000000000000000000000000..00e14ac982358781b379a78d94da05343f88502b --- /dev/null +++ b/tensorflow/python/training/checkpointable/tracking.py @@ -0,0 +1,103 @@ +"""Dependency tracking for checkpointable objects.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.training.checkpointable import base + + +class NoDependency(object): + """Allows attribute assignment to `Checkpointable` objects with no dependency. + + Example usage: + ```python + obj = Checkpointable() + obj.has_dependency = tf.Variable(0., name="dep") + obj.no_dependency = NoDependency(tf.Variable(1., name="nodep")) + assert obj.no_dependency.name == "nodep:0" + ``` + + `obj` in this example has a dependency on the variable "dep", and both + attributes contain un-wrapped `Variable` objects. + + `NoDependency` also works with `tf.keras.Model`, but only for checkpoint + dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped) + `Layer` to the attribute without a checkpoint dependency, but the `Model` will + still track the `Layer` (so it will appear in `Model.layers`, and its + variables will appear in `Model.variables`). + """ + + def __init__(self, value): + self.value = value + + +class NotCheckpointable(object): + """Marks instances of child classes as unsaveable using an object-based API. + + Useful for marking objects which would otherwise look checkpointable because + of inheritance (e.g. through `Layer`) as not checkpointable. Inheriting from + `NotCheckpointable` does not prevent an object from being assigned to any + attributes, but will throw an error on save/restore. + """ + pass + + +class Checkpointable(base.CheckpointableBase): + """Manages dependencies on other objects. + + `Checkpointable` objects may have dependencies: other `Checkpointable` objects + which should be saved if the object declaring the dependency is saved. A + correctly saveable program has a dependency graph such that if changing a + global variable affects an object (e.g. changes the behavior of any of its + methods) then there is a chain of dependencies from the influenced object to + the variable. + + Dependency edges have names, and are created implicitly when a + `Checkpointable` object is assigned to an attribute of another + `Checkpointable` object. For example: + + ``` + obj = Checkpointable() + obj.v = ResourceVariable(0.) + ``` + + The `Checkpointable` object `obj` now has a dependency named "v" on a + variable. + + `Checkpointable` objects may specify `Tensor`s to be saved and restored + directly (e.g. a `Variable` indicating how to save itself) rather than through + dependencies on other objects. See + `Checkpointable._gather_saveables_for_checkpoint` for details. + """ + + def __setattr__(self, name, value): + """Support self.foo = checkpointable syntax.""" + # Perform the attribute assignment, and potentially call other __setattr__ + # overrides such as that for tf.keras.Model. + no_dependency = isinstance(value, NoDependency) + if no_dependency: + value = value.value + super(Checkpointable, self).__setattr__(name, value) + if not no_dependency and isinstance(value, base.CheckpointableBase): + self._track_checkpointable( + value, name=name, + # Allow the user to switch the Checkpointable which is tracked by this + # name, since assigning a new variable to an attribute has + # historically been fine (e.g. Adam did this). + # TODO(allenl): Should this be a warning once Checkpointable save/load + # is usable? + overwrite=True) diff --git a/tensorflow/python/training/checkpointable/tracking_test.py b/tensorflow/python/training/checkpointable/tracking_test.py new file mode 100644 index 0000000000000000000000000000000000000000..baf6f57efbc5c71ac3cb0d6b0a3d8f8b115fad1e --- /dev/null +++ b/tensorflow/python/training/checkpointable/tracking_test.py @@ -0,0 +1,49 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import tracking + + +class InterfaceTests(test.TestCase): + + def testMultipleAssignment(self): + root = tracking.Checkpointable() + root.leaf = tracking.Checkpointable() + root.leaf = root.leaf + duplicate_name_dep = tracking.Checkpointable() + with self.assertRaises(ValueError): + root._track_checkpointable(duplicate_name_dep, name="leaf") + # No error; we're overriding __setattr__, so we can't really stop people + # from doing this while maintaining backward compatibility. + root.leaf = duplicate_name_dep + root._track_checkpointable(duplicate_name_dep, name="leaf", overwrite=True) + + def testNoDependency(self): + root = tracking.Checkpointable() + hasdep = tracking.Checkpointable() + root.hasdep = hasdep + nodep = tracking.Checkpointable() + root.nodep = tracking.NoDependency(nodep) + self.assertEqual(1, len(root._checkpoint_dependencies)) + self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep) + self.assertIs(root.hasdep, hasdep) + self.assertIs(root.nodep, nodep) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py index 0608076e6d94e1737754c49088db30c313ef417a..e0f61137b1026a64a8cc9703ac33997c55f93a4f 100644 --- a/tensorflow/python/training/checkpointable/util.py +++ b/tensorflow/python/training/checkpointable/util.py @@ -39,7 +39,8 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import saveable_object as saveable_object_lib from tensorflow.python.training import saver as saver_lib -from tensorflow.python.training.checkpointable import base as checkpointable_lib +from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.checkpointable import tracking from tensorflow.python.util import deprecation from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export @@ -114,7 +115,7 @@ class _CheckpointRestoreCoordinator(object): # `node` refers to an `Optimizer`, since only these have slot variables. self.slot_restorations.setdefault( slot_reference.original_variable_node_id, []).append( - checkpointable_lib._SlotVariableRestoration( # pylint: disable=protected-access + base._SlotVariableRestoration( # pylint: disable=protected-access optimizer_id=node_index, slot_variable_id=slot_reference.slot_variable_node_id, slot_name=slot_reference.slot_name)) @@ -258,13 +259,13 @@ def object_metadata(save_path): reader = pywrap_tensorflow.NewCheckpointReader(save_path) try: object_graph_string = reader.get_tensor( - checkpointable_lib.OBJECT_GRAPH_PROTO_KEY) + base.OBJECT_GRAPH_PROTO_KEY) except errors_impl.NotFoundError: raise ValueError( ('The specified checkpoint "%s" does not appear to be object-based (it ' 'is missing the key "%s"). Likely it was created with a name-based ' 'saver and does not contain an object dependency graph.') % ( - save_path, checkpointable_lib.OBJECT_GRAPH_PROTO_KEY)) + save_path, base.OBJECT_GRAPH_PROTO_KEY)) object_graph_proto = ( checkpointable_object_graph_pb2.CheckpointableObjectGraph()) object_graph_proto.ParseFromString(object_graph_string) @@ -278,7 +279,7 @@ def _breadth_first_checkpointable_traversal(root_checkpointable): path_to_root = {root_checkpointable: ()} while to_visit: current_checkpointable = to_visit.popleft() - if isinstance(current_checkpointable, checkpointable_lib.NotCheckpointable): + if isinstance(current_checkpointable, tracking.NotCheckpointable): raise NotImplementedError( ("The object %s does not support object-based saving. File a feature " "request if this limitation bothers you. In the meantime, you can " @@ -1038,11 +1039,11 @@ class CheckpointableSaver(object): with ops.device("/cpu:0"): object_graph_tensor = constant_op.constant( graph_proto.SerializeToString(), dtype=dtypes.string) - assert checkpointable_lib.OBJECT_GRAPH_PROTO_KEY not in named_variables + assert base.OBJECT_GRAPH_PROTO_KEY not in named_variables named_variables.append( _NoRestoreSaveable( tensor=object_graph_tensor, - name=checkpointable_lib.OBJECT_GRAPH_PROTO_KEY)) + name=base.OBJECT_GRAPH_PROTO_KEY)) if (self._last_save_object_graph != graph_proto # When executing eagerly, we need to re-create SaveableObjects each time # save() is called so they pick up new Tensors passed to their @@ -1132,7 +1133,7 @@ class CheckpointableSaver(object): dtype_map = reader.get_variable_to_dtype_map() try: object_graph_string = reader.get_tensor( - checkpointable_lib.OBJECT_GRAPH_PROTO_KEY) + base.OBJECT_GRAPH_PROTO_KEY) except errors_impl.NotFoundError: # The object graph proto does not exist in this checkpoint. Try the # name-based compatibility mode. @@ -1178,7 +1179,7 @@ class CheckpointableSaver(object): "file a feature request if this limitation bothers you.") self._last_restore_checkpoint = checkpoint self._last_restore_object_graph = object_graph_proto - checkpointable_lib._CheckpointPosition( # pylint: disable=protected-access + base._CheckpointPosition( # pylint: disable=protected-access checkpoint=checkpoint, proto_id=0).restore(self._root_checkpointable) load_status = CheckpointLoadStatus( checkpoint, @@ -1188,7 +1189,7 @@ class CheckpointableSaver(object): @tf_export("train.Checkpoint") -class Checkpoint(checkpointable_lib.Checkpointable): +class Checkpoint(tracking.Checkpointable): """Groups checkpointable objects, saving and restoring them. `Checkpoint`'s constructor accepts keyword arguments whose values are types @@ -1290,7 +1291,7 @@ class Checkpoint(checkpointable_lib.Checkpointable): """ super(Checkpoint, self).__init__() for k, v in sorted(kwargs.items(), key=lambda item: item[0]): - if not isinstance(v, checkpointable_lib.CheckpointableBase): + if not isinstance(v, base.CheckpointableBase): raise ValueError( ("`Checkpoint` was expecting a checkpointable object (an object " "derived from `CheckpointableBase`), got %s. If you believe this " @@ -1309,7 +1310,7 @@ class Checkpoint(checkpointable_lib.Checkpointable): with ops.device("/cpu:0"): # add_variable creates a dependency named "save_counter"; NoDependency # prevents creating a second dependency named "_save_counter". - self._save_counter = checkpointable_lib.NoDependency( + self._save_counter = tracking.NoDependency( add_variable(self, name="save_counter", initializer=0, dtype=dtypes.int64)) diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py index e2115417c44058237b19fb66b17de44b229f1ac4..896ea47b974a334d34e520e6f3c2ad947dea12a2 100644 --- a/tensorflow/python/training/checkpointable/util_test.py +++ b/tensorflow/python/training/checkpointable/util_test.py @@ -44,11 +44,12 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.training import adam from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.checkpointable import tracking from tensorflow.python.training.checkpointable import util as checkpointable_utils -class NonLayerCheckpointable(checkpointable.Checkpointable): +class NonLayerCheckpointable(tracking.Checkpointable): def __init__(self): super(NonLayerCheckpointable, self).__init__() @@ -136,7 +137,7 @@ class InterfaceTests(test.TestCase): def testInitNotCalled(self): - class NoInit(checkpointable.Checkpointable): + class NoInit(tracking.Checkpointable): def __init__(self): pass @@ -145,7 +146,7 @@ class InterfaceTests(test.TestCase): checkpointable_utils.add_variable(NoInit(), "var", shape=[]) def testShapeDtype(self): - root = checkpointable.Checkpointable() + root = tracking.Checkpointable() v1 = checkpointable_utils.add_variable( root, name="v1", initializer=3., dtype=dtypes.float64) self.assertEqual(dtypes.float64, v1.dtype) @@ -177,7 +178,7 @@ class InterfaceTests(test.TestCase): def testNotCheckpointable(self): class CallsFunctionalStuff( - checkpointable.NotCheckpointable, checkpointable.Checkpointable): + tracking.NotCheckpointable, tracking.Checkpointable): pass test_dir = self.get_temp_dir() @@ -187,7 +188,7 @@ class InterfaceTests(test.TestCase): checkpoint.save(prefix) class CallsFunctionalStuffOtherMRO( - checkpointable.Checkpointable, checkpointable.NotCheckpointable): + tracking.Checkpointable, tracking.NotCheckpointable): pass checkpoint_reversed = checkpointable_utils.Checkpoint( @@ -217,7 +218,7 @@ class _MirroringSaveable(saver_lib.BaseSaverBuilder.SaveableObject): self._mirrored_variable.assign(tensor)) -class _OwnsMirroredVariables(checkpointable.CheckpointableBase): +class _OwnsMirroredVariables(base.CheckpointableBase): """A Checkpointable object which returns a more complex SaveableObject.""" def __init__(self): @@ -232,7 +233,7 @@ class _OwnsMirroredVariables(checkpointable.CheckpointableBase): primary_variable=self.non_dep_variable, mirrored_variable=self.mirrored, name=name) - return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} + return {base.VARIABLE_VALUE_KEY: _saveable_factory} # The Saver sorts by name before parsing, so we need a name property. @property @@ -355,7 +356,7 @@ class CheckpointingTests(test.TestCase): optimizer_node.slot_variables[0] .slot_variable_node_id].attributes[0].checkpoint_key) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMoreComplexSaveableReturned(self): v = _OwnsMirroredVariables() checkpoint = checkpointable_utils.Checkpoint(v=v) @@ -375,7 +376,7 @@ class CheckpointingTests(test.TestCase): self.assertEqual(44., self.evaluate(v.non_dep_variable)) self.assertEqual(44., self.evaluate(v.mirrored)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMoreComplexSaveableReturnedWithGlobalName(self): # The same object can also be saved using the name-based saver. v = _OwnsMirroredVariables() @@ -391,7 +392,7 @@ class CheckpointingTests(test.TestCase): self.assertEqual(42., self.evaluate(v.non_dep_variable)) self.assertEqual(42., self.evaluate(v.mirrored)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSaveRestore(self): model = MyModel() optimizer = adam.AdamOptimizer(0.001) @@ -512,7 +513,7 @@ class CheckpointingTests(test.TestCase): self.assertEqual(training_continuation + 1, session.run(root.save_counter)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAgnosticUsage(self): """Graph/eager agnostic usage.""" # Does create garbage when executing eagerly due to ops.Graph() creation. @@ -546,7 +547,7 @@ class CheckpointingTests(test.TestCase): self.evaluate(root.save_counter)) # pylint: disable=cell-var-from-loop - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testWithDefun(self): num_training_steps = 2 checkpoint_directory = self.get_temp_dir() @@ -590,7 +591,7 @@ class CheckpointingTests(test.TestCase): # pylint: enable=cell-var-from-loop def _get_checkpoint_name(self, name): - root = checkpointable.Checkpointable() + root = tracking.Checkpointable() checkpointable_utils.add_variable( root, name=name, shape=[1, 2], dtype=dtypes.float64) (named_variable,), _, _ = checkpointable_utils._serialize_object_graph( @@ -611,18 +612,18 @@ class CheckpointingTests(test.TestCase): @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testNumberedPath(self): - root = checkpointable.Checkpointable() - leaf = checkpointable.Checkpointable() + root = tracking.Checkpointable() + leaf = tracking.Checkpointable() root.leaf = leaf checkpointable_utils.add_variable(leaf, name="v", shape=[]) (named_variable,), _, _ = checkpointable_utils._serialize_object_graph( root, saveables_cache=None) self.assertEqual(r"leaf/v/.ATTRIBUTES/VARIABLE_VALUE", named_variable.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLocalNameValidation(self): - root = checkpointable.Checkpointable() - leaf = checkpointable.Checkpointable() + root = tracking.Checkpointable() + leaf = tracking.Checkpointable() # Dots are escaped, which avoids conflicts with reserved names. root._track_checkpointable(leaf, name=".ATTRIBUTES") checkpointable_utils.add_variable(checkpointable=leaf, name="a", shape=[]) @@ -660,16 +661,16 @@ class CheckpointingTests(test.TestCase): optimizer.apply_gradients( [(g, v) for g, v in zip(grad, model.vars)]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLateDependencyTracking(self): - class Dependency(checkpointable.Checkpointable): + class Dependency(tracking.Checkpointable): def build(self): self.var = checkpointable_utils.add_variable( self, "var", initializer=0.) - class LateDependencies(checkpointable.Checkpointable): + class LateDependencies(tracking.Checkpointable): def add_dep(self): self.dep = Dependency() @@ -692,16 +693,16 @@ class CheckpointingTests(test.TestCase): status.run_restore_ops() self.assertEqual(123., self.evaluate(load_into.dep.var)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDepAfterVar(self): - class Dependency(checkpointable.Checkpointable): + class Dependency(tracking.Checkpointable): def build(self): self.var = checkpointable_utils.add_variable( self, "var", initializer=0.) - class DepAfterVar(checkpointable.Checkpointable): + class DepAfterVar(tracking.Checkpointable): def add_dep(self): dep = Dependency() @@ -724,11 +725,11 @@ class CheckpointingTests(test.TestCase): status.run_restore_ops() self.assertEqual(-14., self.evaluate(loaded_dep_after_var.dep.var)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDeferredSlotRestoration(self): checkpoint_directory = self.get_temp_dir() - root = checkpointable.Checkpointable() + root = tracking.Checkpointable() root.var = checkpointable_utils.add_variable( root, name="var", initializer=0.) optimizer = adam.AdamOptimizer(0.1) @@ -751,7 +752,7 @@ class CheckpointingTests(test.TestCase): 14.)) slots_path = checkpointable_utils.CheckpointableSaver(root).save( os.path.join(checkpoint_directory, "with_slots")) - new_root = checkpointable.Checkpointable() + new_root = tracking.Checkpointable() # Load the slot-containing checkpoint (deferred), then immediately overwrite # the non-slot variable (also deferred). slot_status = checkpointable_utils.CheckpointableSaver( @@ -789,11 +790,11 @@ class CheckpointingTests(test.TestCase): self.evaluate(train_op) slot_status.assert_consumed() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testOverlappingRestores(self): checkpoint_directory = self.get_temp_dir() - save_root = checkpointable.Checkpointable() - save_root.dep = checkpointable.Checkpointable() + save_root = tracking.Checkpointable() + save_root.dep = tracking.Checkpointable() save_root.dep.var = checkpointable_utils.add_variable( save_root.dep, name="var", initializer=0.) self.evaluate(state_ops.assign(save_root.dep.var, 12.)) @@ -802,13 +803,13 @@ class CheckpointingTests(test.TestCase): self.evaluate(state_ops.assign(save_root.dep.var, 13.)) second_path = saver.save(os.path.join(checkpoint_directory, "second")) - first_root = checkpointable.Checkpointable() - second_root = checkpointable.Checkpointable() + first_root = tracking.Checkpointable() + second_root = tracking.Checkpointable() first_status = checkpointable_utils.CheckpointableSaver( first_root).restore(first_path) second_status = checkpointable_utils.CheckpointableSaver( second_root).restore(second_path) - load_dep = checkpointable.Checkpointable() + load_dep = tracking.Checkpointable() load_dep.var = checkpointable_utils.add_variable( load_dep, name="var", shape=[]) first_root.dep = load_dep @@ -822,13 +823,13 @@ class CheckpointingTests(test.TestCase): # Try again with the order of the restore() reversed. The last restore # determines the final value. - first_root = checkpointable.Checkpointable() - second_root = checkpointable.Checkpointable() + first_root = tracking.Checkpointable() + second_root = tracking.Checkpointable() second_status = checkpointable_utils.CheckpointableSaver( second_root).restore(second_path) first_status = checkpointable_utils.CheckpointableSaver( first_root).restore(first_path) - load_dep = checkpointable.Checkpointable() + load_dep = tracking.Checkpointable() load_dep.var = checkpointable_utils.add_variable( load_dep, name="var", shape=[]) first_root.dep = load_dep @@ -840,39 +841,39 @@ class CheckpointingTests(test.TestCase): second_status.run_restore_ops() self.assertEqual(12., self.evaluate(load_dep.var)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAmbiguousLoad(self): # Not OK to split one checkpoint object into two checkpoint_directory = self.get_temp_dir() - save_root = checkpointable.Checkpointable() - save_root.dep_one = checkpointable.Checkpointable() - save_root.dep_two = checkpointable.Checkpointable() - dep_three = checkpointable.Checkpointable() + save_root = tracking.Checkpointable() + save_root.dep_one = tracking.Checkpointable() + save_root.dep_two = tracking.Checkpointable() + dep_three = tracking.Checkpointable() save_root.dep_one.dep_three = dep_three save_root.dep_two.dep_three = dep_three checkpointable_utils.add_variable(dep_three, name="var", initializer=0.) self.evaluate(checkpointable_utils.gather_initializers(save_root)) save_path = checkpointable_utils.CheckpointableSaver(save_root).save( os.path.join(checkpoint_directory, "ckpt")) - load_root = checkpointable.Checkpointable() + load_root = tracking.Checkpointable() status = checkpointable_utils.CheckpointableSaver(load_root).restore( save_path) - load_root.dep_one = checkpointable.Checkpointable() - load_root.dep_two = checkpointable.Checkpointable() - load_root.dep_one.dep_three = checkpointable.Checkpointable() - load_root.dep_two.dep_three = checkpointable.Checkpointable() + load_root.dep_one = tracking.Checkpointable() + load_root.dep_two = tracking.Checkpointable() + load_root.dep_one.dep_three = tracking.Checkpointable() + load_root.dep_two.dep_three = tracking.Checkpointable() checkpointable_utils.add_variable( load_root.dep_one.dep_three, name="var", initializer=0.) with self.assertRaises(AssertionError): status.assert_consumed() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testObjectsCombined(self): # Currently fine to load two checkpoint objects into one Python object checkpoint_directory = self.get_temp_dir() - save_root = checkpointable.Checkpointable() - save_root.dep_one = checkpointable.Checkpointable() - save_root.dep_two = checkpointable.Checkpointable() + save_root = tracking.Checkpointable() + save_root.dep_one = tracking.Checkpointable() + save_root.dep_two = tracking.Checkpointable() checkpointable_utils.add_variable( save_root.dep_one, name="var1", initializer=32., dtype=dtypes.float64) checkpointable_utils.add_variable( @@ -880,8 +881,8 @@ class CheckpointingTests(test.TestCase): self.evaluate(checkpointable_utils.gather_initializers(save_root)) save_path = checkpointable_utils.CheckpointableSaver(save_root).save( os.path.join(checkpoint_directory, "ckpt")) - load_root = checkpointable.Checkpointable() - load_root.dep_one = checkpointable.Checkpointable() + load_root = tracking.Checkpointable() + load_root.dep_one = tracking.Checkpointable() load_root.dep_two = load_root.dep_one v1 = checkpointable_utils.add_variable( load_root.dep_one, name="var1", shape=[], dtype=dtypes.float64) @@ -893,12 +894,12 @@ class CheckpointingTests(test.TestCase): self.assertEqual(32., self.evaluate(v1)) self.assertEqual(64., self.evaluate(v2)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDependencyLoop(self): # Note: this test creates garbage during eager execution because it # purposefully creates a reference cycle. - first = checkpointable.Checkpointable() - second = checkpointable.Checkpointable() + first = tracking.Checkpointable() + second = tracking.Checkpointable() first.second = second second.first = first first.v = checkpointable_utils.add_variable( @@ -911,10 +912,10 @@ class CheckpointingTests(test.TestCase): os.path.join(checkpoint_directory, "ckpt")) # Test deferred loading - first_load = checkpointable.Checkpointable() + first_load = tracking.Checkpointable() status = checkpointable_utils.CheckpointableSaver( first_load).restore(save_path) - second_load = checkpointable.Checkpointable() + second_load = tracking.Checkpointable() first_load.second = second_load second_load.first = first_load with self.assertRaises(AssertionError): @@ -939,13 +940,13 @@ class CheckpointingTests(test.TestCase): self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v)) self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRestoreOnAssign(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") save_graph = ops.Graph() with save_graph.as_default(), self.test_session(save_graph): - first = checkpointable.Checkpointable() + first = tracking.Checkpointable() first.var1 = variable_scope.get_variable( name="outside_var", initializer=0.) first.var2 = variable_scope.get_variable( @@ -956,7 +957,7 @@ class CheckpointingTests(test.TestCase): checkpoint_prefix) restore_graph = ops.Graph() with restore_graph.as_default(), self.test_session(restore_graph): - second = checkpointable.Checkpointable() + second = tracking.Checkpointable() second.var2 = variable_scope.get_variable( name="blah", initializer=0.) status = checkpointable_utils.CheckpointableSaver( @@ -978,7 +979,7 @@ class CheckpointingTests(test.TestCase): with graph.as_default(), self.test_session(graph): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = checkpointable.Checkpointable() + obj = tracking.Checkpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) @@ -989,11 +990,11 @@ class CheckpointingTests(test.TestCase): saver.save(checkpoint_prefix) self.assertEqual(before_ops, graph.get_operations()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCheckpointCleanup(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = checkpointable.Checkpointable() + obj = tracking.Checkpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) self.evaluate(checkpointable_utils.gather_initializers(obj)) saver = checkpointable_utils.Checkpoint(obj=obj) @@ -1009,11 +1010,11 @@ class CheckpointingTests(test.TestCase): expected_filenames, os.listdir(checkpoint_directory)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCheckpointCleanupChangingVarList(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = checkpointable.Checkpointable() + obj = tracking.Checkpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) self.evaluate(checkpointable_utils.gather_initializers(obj)) checkpoint = checkpointable_utils.Checkpoint(obj=obj) @@ -1062,7 +1063,7 @@ class CheckpointingTests(test.TestCase): with graph.as_default(), self.test_session(graph): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = checkpointable.Checkpointable() + obj = tracking.Checkpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) @@ -1132,7 +1133,7 @@ class CheckpointingTests(test.TestCase): beta1_power, _ = optimizer._get_beta_accumulators() self.assertAllEqual(3., self.evaluate(beta1_power)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_sequential(self): model = sequential.Sequential() checkpoint = checkpointable_utils.Checkpoint(model=model) @@ -1164,7 +1165,7 @@ class CheckpointingTests(test.TestCase): self.assertAllEqual([1., 2., 3., 4., 5.], self.evaluate(deferred_second_dense.bias)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_initialize_if_not_restoring(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") @@ -1243,7 +1244,7 @@ class CheckpointingTests(test.TestCase): self.assertEqual(42., self.evaluate(optimizer.variables()[0])) -class _ManualScope(checkpointable.Checkpointable): +class _ManualScope(tracking.Checkpointable): def __call__(self): with variable_scope.variable_scope("ManualScope") as vs: @@ -1257,7 +1258,7 @@ class _ManualScope(checkpointable.Checkpointable): class TemplateTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_checkpointable_save_restore(self): def _templated(): @@ -1308,7 +1309,7 @@ class TemplateTests(test.TestCase): self.assertAllEqual([13.], self.evaluate(var_plus_one)) self.assertAllEqual([14.], self.evaluate(var2)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_checkpointable_save_restore_nested(self): def _inner_template(): @@ -1409,7 +1410,7 @@ class CheckpointCompatibilityTests(test.TestCase): sess=session, save_path=checkpoint_prefix, global_step=root.optimizer_step) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLoadFromNameBasedSaver(self): """Save a name-based checkpoint, load it using the object-based API.""" with test_util.device(use_gpu=True): @@ -1471,7 +1472,7 @@ class CheckpointCompatibilityTests(test.TestCase): class PythonMetadataTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSaveLoad(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py index caffd042a0917209c87cab8993169dc4bc956039..6a326b65bbe956953bd414c8e89fd9f5cce58f48 100644 --- a/tensorflow/python/training/distribute.py +++ b/tensorflow/python/training/distribute.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import threading -import six from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import ops @@ -527,9 +526,13 @@ class DistributionStrategy(object): V(`v`), output will have locality V(`v`) as well. * `d.update_non_slot(d.non_slot_devices(), fn)`: in cross-tower context, like `d.update()` except with locality N. - * `d.fetch(t)`: Copy `t` with any locality to the client's CPU device. - TODO(josh11b): Deprecate `fetch`, switch to `read_var` for - reading tower-local variables. + * `d.read_var(v)`: Gets the (read-only) value of the variable `v` (on + the device determined by the current device scope), aggregating + across towers for tower-local variables. Frequently, this will be + done automatically when using `v` in an expression or fetching it in + a cross-tower context, but this function can be used to force that + conversion happens at a particular point in time (for example, to + add the result of the conversion to a graph collection). The standard pattern for updating variables is to: @@ -616,13 +619,13 @@ class DistributionStrategy(object): There will still be one component variable per tower, but there is no requirement that they stay in sync. Instead, when saving them - or calling `fetch()/read_var()`, we use the value that - results when calling `reduce()` on all the towers' variables. + or calling `read_var()`, we use the value that results when + calling `reduce()` on all the towers' variables. Note: tower-local implies not trainable. Instead, it is expected that each tower will directly update (using `assign_add()` or whatever) its local variable instance but only the aggregated - value (accessible using `fetch()`) will be exported from the + value (accessible using `read_var()`) will be exported from the model. When it is acceptable to only aggregate on export, we greatly reduce communication overhead by using tower-local variables. @@ -914,32 +917,6 @@ class DistributionStrategy(object): def _update_non_slot(self, colocate_with, fn, *args, **kwargs): raise NotImplementedError("must be implemented in descendants") - def fetch(self, val, destination="/device:CPU:0", fn=lambda x: x): - """Return a copy of `val` or `fn(val)` on `destination`. - - This is useful for getting a mirrored value onto a device. It - will attempt to avoid a copy by checking if the value is already - on the destination device. - - TODO(josh11b): Switch to `read_var`. - - Args: - val: Value (which may be mirrored) to copy. - destination: A device string to copy the value to. - fn: An optional function to apply to the value on the source - device, before copying. - - Returns: - A `Tensor` on `destination`. - """ - _require_cross_tower_context(self) - assert isinstance(destination, six.string_types) - destination = device_util.resolve(destination) - return self._fetch(val, destination, fn) - - def _fetch(self, val, destination, fn): - raise NotImplementedError("must be implemented in descendants") - def unwrap(self, value): """Returns the list of all per-device values contained in `value`. @@ -1219,12 +1196,6 @@ class _DefaultDistributionStrategy(DistributionStrategy): def read_var(self, tower_local_var): return array_ops.identity(tower_local_var) - def _fetch(self, var, destination, fn): - with ops.colocate_with(var): - var = fn(var) - with ops.device(destination): - return array_ops.identity(var) - def _unwrap(self, distributed_value): return [distributed_value] diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py index 10ab4c1137ff226d88902143d4f2281ad77de531..51190264e81ad177c56a6864b616aee52d954c43 100644 --- a/tensorflow/python/training/learning_rate_decay.py +++ b/tensorflow/python/training/learning_rate_decay.py @@ -19,6 +19,7 @@ from __future__ import print_function import math +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -87,6 +88,12 @@ def exponential_decay(learning_rate, Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("global_step is required for exponential_decay.") @@ -95,14 +102,22 @@ def exponential_decay(learning_rate, [learning_rate, global_step, decay_steps, decay_rate]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) decay_steps = math_ops.cast(decay_steps, dtype) decay_rate = math_ops.cast(decay_rate, dtype) - p = global_step / decay_steps - if staircase: - p = math_ops.floor(p) - return math_ops.multiply( - learning_rate, math_ops.pow(decay_rate, p), name=name) + + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + p = global_step_recomp / decay_steps + if staircase: + p = math_ops.floor(p) + return math_ops.multiply( + learning_rate, math_ops.pow(decay_rate, p), name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.piecewise_constant") @@ -141,48 +156,62 @@ def piecewise_constant(x, boundaries, values, name=None): ValueError: if types of `x` and `boundaries` do not match, or types of all `values` do not match or the number of elements in the lists does not match. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if len(boundaries) != len(values) - 1: raise ValueError( "The length of boundaries should be 1 less than the length of values") with ops.name_scope(name, "PiecewiseConstant", [x, boundaries, values, name]) as name: - x = ops.convert_to_tensor(x) - # Avoid explicit conversion to x's dtype. This could result in faulty - # comparisons, for example if floats are converted to integers. boundaries = ops.convert_n_to_tensor(boundaries) - for i, b in enumerate(boundaries): - if b.dtype.base_dtype != x.dtype.base_dtype: - # We can promote int32 boundaries to int64 without loss of precision. - # This covers the most common case where the user passes in boundaries - # as an array of Python integers. - if (b.dtype.base_dtype == dtypes.int32 and - x.dtype.base_dtype == dtypes.int64): - b = math_ops.cast(b, x.dtype.base_dtype) - boundaries[i] = b - else: - raise ValueError( - "Boundaries (%s) must have the same dtype as x (%s)." % - (b.dtype.base_dtype, x.dtype.base_dtype)) - # TODO(rdipietro): Ensure that boundaries' elements are strictly increasing. values = ops.convert_n_to_tensor(values) - for v in values[1:]: - if v.dtype.base_dtype != values[0].dtype.base_dtype: - raise ValueError( - "Values must have elements all with the same dtype (%s vs %s)." % - (values[0].dtype.base_dtype, v.dtype.base_dtype)) - pred_fn_pairs = [] - pred_fn_pairs.append((x <= boundaries[0], lambda: values[0])) - pred_fn_pairs.append((x > boundaries[-1], lambda: values[-1])) - for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]): - # Need to bind v here; can do this with lambda v=v: ... - pred = (x > low) & (x <= high) - pred_fn_pairs.append((pred, lambda v=v: v)) - - # The default isn't needed here because our conditions are mutually - # exclusive and exhaustive, but tf.case requires it. - default = lambda: values[0] - return control_flow_ops.case(pred_fn_pairs, default, exclusive=True) + + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + x_recomp = ops.convert_to_tensor(x) + # Avoid explicit conversion to x's dtype. This could result in faulty + # comparisons, for example if floats are converted to integers. + for i, b in enumerate(boundaries): + if b.dtype.base_dtype != x_recomp.dtype.base_dtype: + # We can promote int32 boundaries to int64 without loss of precision. + # This covers the most common case where the user passes in boundaries + # as an array of Python integers. + if (b.dtype.base_dtype == dtypes.int32 and + x_recomp.dtype.base_dtype == dtypes.int64): + b = math_ops.cast(b, x_recomp.dtype.base_dtype) + boundaries[i] = b + else: + raise ValueError( + "Boundaries (%s) must have the same dtype as x (%s)." % + (b.dtype.base_dtype, x_recomp.dtype.base_dtype)) + # TODO(rdipietro): Ensure that boundaries' elements strictly increases. + for v in values[1:]: + if v.dtype.base_dtype != values[0].dtype.base_dtype: + raise ValueError( + "Values must have elements all with the same dtype (%s vs %s)." % + (values[0].dtype.base_dtype, v.dtype.base_dtype)) + pred_fn_pairs = [] + pred_fn_pairs.append((x_recomp <= boundaries[0], lambda: values[0])) + pred_fn_pairs.append((x_recomp > boundaries[-1], lambda: values[-1])) + for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]): + # Need to bind v here; can do this with lambda v=v: ... + pred = (x_recomp > low) & (x_recomp <= high) + pred_fn_pairs.append((pred, lambda v=v: v)) + + # The default isn't needed here because our conditions are mutually + # exclusive and exhaustive, but tf.case requires it. + default = lambda: values[0] + return control_flow_ops.case(pred_fn_pairs, default, exclusive=True) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.polynomial_decay") @@ -263,6 +292,12 @@ def polynomial_decay(learning_rate, Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("global_step is required for polynomial_decay.") @@ -272,27 +307,35 @@ def polynomial_decay(learning_rate, ]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) - decay_steps = math_ops.cast(decay_steps, dtype) end_learning_rate = math_ops.cast(end_learning_rate, dtype) power = math_ops.cast(power, dtype) - if cycle: - # Find the first multiple of decay_steps that is bigger than global_step. - # If global_step is zero set the multiplier to 1 - multiplier = control_flow_ops.cond( - math_ops.equal(global_step, 0), lambda: 1.0, - lambda: math_ops.ceil(global_step / decay_steps)) - decay_steps = math_ops.multiply(decay_steps, multiplier) - else: - # Make sure that the global_step used is not bigger than decay_steps. - global_step = math_ops.minimum(global_step, decay_steps) - - p = math_ops.div(global_step, decay_steps) - return math_ops.add( - math_ops.multiply(learning_rate - end_learning_rate, - math_ops.pow(1 - p, power)), - end_learning_rate, - name=name) + + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + decay_steps_recomp = math_ops.cast(decay_steps, dtype) + if cycle: + # Find the first multiple of decay_steps that is bigger than + # global_step. If global_step is zero set the multiplier to 1 + multiplier = control_flow_ops.cond( + math_ops.equal(global_step_recomp, 0), lambda: 1.0, + lambda: math_ops.ceil(global_step_recomp / decay_steps)) + decay_steps_recomp = math_ops.multiply(decay_steps_recomp, multiplier) + else: + # Make sure that the global_step used is not bigger than decay_steps. + global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) + + p = math_ops.div(global_step_recomp, decay_steps_recomp) + return math_ops.add( + math_ops.multiply(learning_rate - end_learning_rate, + math_ops.pow(1 - p, power)), + end_learning_rate, + name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.natural_exp_decay") @@ -350,6 +393,12 @@ def natural_exp_decay(learning_rate, Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("global_step is required for natural_exp_decay.") @@ -357,14 +406,23 @@ def natural_exp_decay(learning_rate, [learning_rate, global_step, decay_rate]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) decay_steps = math_ops.cast(decay_steps, dtype) decay_rate = math_ops.cast(decay_rate, dtype) - p = global_step / decay_steps - if staircase: - p = math_ops.floor(p) - exponent = math_ops.exp(math_ops.multiply(math_ops.negative(decay_rate), p)) - return math_ops.multiply(learning_rate, exponent, name=name) + + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + p = global_step_recomp / decay_steps + if staircase: + p = math_ops.floor(p) + exponent = math_ops.exp( + math_ops.multiply(math_ops.negative(decay_rate), p)) + return math_ops.multiply(learning_rate, exponent, name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.inverse_time_decay") @@ -432,6 +490,12 @@ def inverse_time_decay(learning_rate, Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("global_step is required for inverse_time_decay.") @@ -439,15 +503,23 @@ def inverse_time_decay(learning_rate, [learning_rate, global_step, decay_rate]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) decay_steps = math_ops.cast(decay_steps, dtype) decay_rate = math_ops.cast(decay_rate, dtype) - p = global_step / decay_steps - if staircase: - p = math_ops.floor(p) - const = math_ops.cast(constant_op.constant(1), learning_rate.dtype) - denom = math_ops.add(const, math_ops.multiply(decay_rate, p)) - return math_ops.div(learning_rate, denom, name=name) + + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + p = global_step_recomp / decay_steps + if staircase: + p = math_ops.floor(p) + const = math_ops.cast(constant_op.constant(1), dtype) + denom = math_ops.add(const, math_ops.multiply(decay_rate, p)) + return math_ops.div(learning_rate, denom, name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.cosine_decay") @@ -492,6 +564,12 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None): learning rate. Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("cosine decay requires global_step") @@ -499,15 +577,23 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None): [learning_rate, global_step]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) decay_steps = math_ops.cast(decay_steps, dtype) - global_step = math_ops.minimum(global_step, decay_steps) - completed_fraction = global_step / decay_steps - cosine_decayed = 0.5 * ( - 1.0 + math_ops.cos(constant_op.constant(math.pi) * completed_fraction)) - decayed = (1 - alpha) * cosine_decayed + alpha - return math_ops.multiply(learning_rate, decayed) + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) + completed_fraction = global_step_recomp / decay_steps + cosine_decayed = 0.5 * (1.0 + math_ops.cos( + constant_op.constant(math.pi) * completed_fraction)) + + decayed = (1 - alpha) * cosine_decayed + alpha + return math_ops.multiply(learning_rate, decayed) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.cosine_decay_restarts") @@ -561,6 +647,12 @@ def cosine_decay_restarts(learning_rate, learning rate. Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("cosine decay restarts requires global_step") @@ -568,40 +660,48 @@ def cosine_decay_restarts(learning_rate, learning_rate = ops.convert_to_tensor( learning_rate, name="initial_learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) first_decay_steps = math_ops.cast(first_decay_steps, dtype) alpha = math_ops.cast(alpha, dtype) t_mul = math_ops.cast(t_mul, dtype) m_mul = math_ops.cast(m_mul, dtype) - completed_fraction = global_step / first_decay_steps + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + completed_fraction = global_step_recomp / first_decay_steps - def compute_step(completed_fraction, geometric=False): - if geometric: - i_restart = math_ops.floor( - math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) / - math_ops.log(t_mul)) + def compute_step(completed_fraction, geometric=False): + """Helper for `cond` operation.""" + if geometric: + i_restart = math_ops.floor( + math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) / + math_ops.log(t_mul)) - sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul) - completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart + sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul) + completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart - else: - i_restart = math_ops.floor(completed_fraction) - completed_fraction = completed_fraction - i_restart + else: + i_restart = math_ops.floor(completed_fraction) + completed_fraction -= i_restart + + return i_restart, completed_fraction - return i_restart, completed_fraction + i_restart, completed_fraction = control_flow_ops.cond( + math_ops.equal(t_mul, 1.0), + lambda: compute_step(completed_fraction, geometric=False), + lambda: compute_step(completed_fraction, geometric=True)) - i_restart, completed_fraction = control_flow_ops.cond( - math_ops.equal(t_mul, 1.0), - lambda: compute_step(completed_fraction, geometric=False), - lambda: compute_step(completed_fraction, geometric=True)) + m_fac = m_mul**i_restart + cosine_decayed = 0.5 * m_fac * (1.0 + math_ops.cos( + constant_op.constant(math.pi) * completed_fraction)) + decayed = (1 - alpha) * cosine_decayed + alpha - m_fac = m_mul**i_restart - cosine_decayed = 0.5 * m_fac * ( - 1.0 + math_ops.cos(constant_op.constant(math.pi) * completed_fraction)) - decayed = (1 - alpha) * cosine_decayed + alpha + return math_ops.multiply(learning_rate, decayed, name=name) - return math_ops.multiply(learning_rate, decayed, name=name) + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.linear_cosine_decay") @@ -664,6 +764,12 @@ def linear_cosine_decay(learning_rate, learning rate. Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("linear cosine decay requires global_step") @@ -671,21 +777,28 @@ def linear_cosine_decay(learning_rate, [learning_rate, global_step]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) decay_steps = math_ops.cast(decay_steps, dtype) num_periods = math_ops.cast(num_periods, dtype) - global_step = math_ops.minimum(global_step, decay_steps) alpha = math_ops.cast(alpha, dtype) beta = math_ops.cast(beta, dtype) - linear_decayed = (decay_steps - global_step) / decay_steps - completed_fraction = global_step / decay_steps - fraction = 2.0 * num_periods * completed_fraction - cosine_decayed = 0.5 * ( - 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) + linear_decayed = (decay_steps - global_step_recomp) / decay_steps + completed_fraction = global_step_recomp / decay_steps + fraction = 2.0 * num_periods * completed_fraction + cosine_decayed = 0.5 * ( + 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) + + linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta + return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name) - linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta - return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name) + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.noisy_linear_cosine_decay") @@ -756,6 +869,12 @@ def noisy_linear_cosine_decay(learning_rate, learning rate. Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("noisy linear cosine decay requires global_step") @@ -763,29 +882,36 @@ def noisy_linear_cosine_decay(learning_rate, [learning_rate, global_step]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) decay_steps = math_ops.cast(decay_steps, dtype) - global_step = math_ops.minimum(global_step, decay_steps) initial_variance = math_ops.cast(initial_variance, dtype) variance_decay = math_ops.cast(variance_decay, dtype) num_periods = math_ops.cast(num_periods, dtype) alpha = math_ops.cast(alpha, dtype) beta = math_ops.cast(beta, dtype) - linear_decayed = (decay_steps - global_step) / decay_steps - variance = initial_variance / ( - math_ops.pow(1.0 + global_step, variance_decay)) - std = math_ops.sqrt(variance) - noisy_linear_decayed = ( - linear_decayed + - random_ops.random_normal(linear_decayed.shape, stddev=std)) - - completed_fraction = global_step / decay_steps - fraction = 2.0 * num_periods * completed_fraction - cosine_decayed = 0.5 * ( - 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) - noisy_linear_cosine_decayed = ( - (alpha + noisy_linear_decayed) * cosine_decayed + beta) - - return math_ops.multiply( - learning_rate, noisy_linear_cosine_decayed, name=name) + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) + linear_decayed = (decay_steps - global_step_recomp) / decay_steps + variance = initial_variance / ( + math_ops.pow(1.0 + global_step_recomp, variance_decay)) + std = math_ops.sqrt(variance) + noisy_linear_decayed = ( + linear_decayed + random_ops.random_normal( + linear_decayed.shape, stddev=std)) + + completed_fraction = global_step_recomp / decay_steps + fraction = 2.0 * num_periods * completed_fraction + cosine_decayed = 0.5 * ( + 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) + noisy_linear_cosine_decayed = ( + (alpha + noisy_linear_decayed) * cosine_decayed + beta) + + return math_ops.multiply( + learning_rate, noisy_linear_cosine_decayed, name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr diff --git a/tensorflow/python/training/learning_rate_decay_test.py b/tensorflow/python/training/learning_rate_decay_test.py index 60306e4f1239a759ea1f68492a1211d5f0858997..4f3cf01822c5b56c8fd05f859c3a1db302a57625 100644 --- a/tensorflow/python/training/learning_rate_decay_test.py +++ b/tensorflow/python/training/learning_rate_decay_test.py @@ -21,12 +21,9 @@ from __future__ import print_function import math from tensorflow.python.eager import context -from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util -from tensorflow.python.ops import gen_state_ops # Import resource_variable_ops for the variables-to-tensor implicit conversion. from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import -from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest from tensorflow.python.training import learning_rate_decay @@ -34,31 +31,35 @@ from tensorflow.python.training import learning_rate_decay class LRDecayTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes def testContinuous(self): - with self.test_session(): - step = 5 - decayed_lr = learning_rate_decay.exponential_decay(0.05, step, 10, 0.96) - expected = .05 * 0.96 ** (5.0 / 10.0) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + self.evaluate(variables.global_variables_initializer()) + step = 5 + decayed_lr = learning_rate_decay.exponential_decay(0.05, step, 10, 0.96) + expected = .05 * 0.96**(5.0 / 10.0) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + @test_util.run_in_graph_and_eager_modes def testStaircase(self): - with self.test_session(): - step = gen_state_ops.variable(shape=[], dtype=dtypes.int32, - name="step", container="", shared_name="") - assign_100 = state_ops.assign(step, 100) - assign_1 = state_ops.assign(step, 1) - assign_2 = state_ops.assign(step, 2) - decayed_lr = learning_rate_decay.exponential_decay(.1, step, 3, 0.96, - staircase=True) - # No change to learning rate - assign_1.op.run() - self.assertAllClose(decayed_lr.eval(), .1, 1e-6) - assign_2.op.run() - self.assertAllClose(decayed_lr.eval(), .1, 1e-6) + if context.executing_eagerly(): + step = resource_variable_ops.ResourceVariable(0) + self.evaluate(variables.global_variables_initializer()) + decayed_lr = learning_rate_decay.exponential_decay( + .1, step, 3, 0.96, staircase=True) + + # No change to learning rate due to staircase + expected = .1 + self.evaluate(step.assign(1)) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + expected = .1 + self.evaluate(step.assign(2)) + self.assertAllClose(self.evaluate(decayed_lr), .1, 1e-6) + # Decayed learning rate - assign_100.op.run() expected = .1 * 0.96 ** (100 // 3) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + self.evaluate(step.assign(100)) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) def testVariables(self): with self.test_session(): @@ -79,38 +80,44 @@ class LRDecayTest(test_util.TensorFlowTestCase): expected = .1 * 0.96 ** (100 // 3) self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPiecewiseConstant(self): x = resource_variable_ops.ResourceVariable(-999) - def pc(): - return learning_rate_decay.piecewise_constant(x, [100, 110, 120], - [1.0, 0.1, 0.01, 0.001]) + decayed_lr = learning_rate_decay.piecewise_constant( + x, [100, 110, 120], [1.0, 0.1, 0.01, 0.001]) self.evaluate(variables.global_variables_initializer()) - self.assertAllClose(self.evaluate(pc()), 1.0, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 1.0, 1e-6) self.evaluate(x.assign(100)) - self.assertAllClose(self.evaluate(pc()), 1.0, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 1.0, 1e-6) self.evaluate(x.assign(105)) - self.assertAllClose(self.evaluate(pc()), 0.1, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.1, 1e-6) self.evaluate(x.assign(110)) - self.assertAllClose(self.evaluate(pc()), 0.1, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.1, 1e-6) self.evaluate(x.assign(120)) - self.assertAllClose(self.evaluate(pc()), 0.01, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.01, 1e-6) self.evaluate(x.assign(999)) - self.assertAllClose(self.evaluate(pc()), 0.001, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.001, 1e-6) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPiecewiseConstantEdgeCases(self): x_int = resource_variable_ops.ResourceVariable( 0, dtype=variables.dtypes.int32) boundaries, values = [-1.0, 1.0], [1, 2, 3] with self.assertRaises(ValueError): - learning_rate_decay.piecewise_constant(x_int, boundaries, values) + decayed_lr = learning_rate_decay.piecewise_constant( + x_int, boundaries, values) + if context.executing_eagerly(): + decayed_lr() + x = resource_variable_ops.ResourceVariable(0.0) boundaries, values = [-1.0, 1.0], [1.0, 2, 3] with self.assertRaises(ValueError): - learning_rate_decay.piecewise_constant(x, boundaries, values) + decayed_lr = learning_rate_decay.piecewise_constant( + x, boundaries, values) + if context.executing_eagerly(): + decayed_lr() # Test that ref types are valid. if not context.executing_eagerly(): @@ -123,221 +130,205 @@ class LRDecayTest(test_util.TensorFlowTestCase): x_int64 = resource_variable_ops.ResourceVariable( 0, dtype=variables.dtypes.int64) boundaries, values = [1, 2, 3], [0.4, 0.5, 0.6, 0.7] - def pc(): - return learning_rate_decay.piecewise_constant(x_int64, boundaries, values) + decayed_lr = learning_rate_decay.piecewise_constant( + x_int64, boundaries, values) self.evaluate(variables.global_variables_initializer()) - self.assertAllClose(self.evaluate(pc()), 0.4, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.4, 1e-6) self.evaluate(x_int64.assign(1)) - self.assertAllClose(self.evaluate(pc()), 0.4, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.4, 1e-6) self.evaluate(x_int64.assign(2)) - self.assertAllClose(self.evaluate(pc()), 0.5, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.5, 1e-6) self.evaluate(x_int64.assign(3)) - self.assertAllClose(self.evaluate(pc()), 0.6, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.6, 1e-6) self.evaluate(x_int64.assign(4)) - self.assertAllClose(self.evaluate(pc()), 0.7, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.7, 1e-6) class LinearDecayTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes def testHalfWay(self): - with self.test_session(): - step = 5 - lr = 0.05 - end_lr = 0.0 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) - expected = lr * 0.5 - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 5 + lr = 0.05 + end_lr = 0.0 + decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) + expected = lr * 0.5 + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testEnd(self): - with self.test_session(): - step = 10 - lr = 0.05 - end_lr = 0.001 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) - expected = end_lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 10 + lr = 0.05 + end_lr = 0.001 + decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) + expected = end_lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testHalfWayWithEnd(self): - with self.test_session(): - step = 5 - lr = 0.05 - end_lr = 0.001 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) - expected = (lr + end_lr) * 0.5 - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 5 + lr = 0.05 + end_lr = 0.001 + decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) + expected = (lr + end_lr) * 0.5 + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testBeyondEnd(self): - with self.test_session(): - step = 15 - lr = 0.05 - end_lr = 0.001 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) - expected = end_lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 15 + lr = 0.05 + end_lr = 0.001 + decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) + expected = end_lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testBeyondEndWithCycle(self): - with self.test_session(): - step = 15 - lr = 0.05 - end_lr = 0.001 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr, - cycle=True) - expected = (lr - end_lr) * 0.25 + end_lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + step = 15 + lr = 0.05 + end_lr = 0.001 + decayed_lr = learning_rate_decay.polynomial_decay( + lr, step, 10, end_lr, cycle=True) + expected = (lr - end_lr) * 0.25 + end_lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) class SqrtDecayTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes def testHalfWay(self): - with self.test_session(): - step = 5 - lr = 0.05 - end_lr = 0.0 - power = 0.5 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr, - power=power) - expected = lr * 0.5 ** power - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 5 + lr = 0.05 + end_lr = 0.0 + power = 0.5 + decayed_lr = learning_rate_decay.polynomial_decay( + lr, step, 10, end_lr, power=power) + expected = lr * 0.5**power + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testEnd(self): - with self.test_session(): - step = 10 - lr = 0.05 - end_lr = 0.001 - power = 0.5 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr, - power=power) - expected = end_lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 10 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = learning_rate_decay.polynomial_decay( + lr, step, 10, end_lr, power=power) + expected = end_lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testHalfWayWithEnd(self): - with self.test_session(): - step = 5 - lr = 0.05 - end_lr = 0.001 - power = 0.5 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr, - power=power) - expected = (lr - end_lr) * 0.5 ** power + end_lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 5 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = learning_rate_decay.polynomial_decay( + lr, step, 10, end_lr, power=power) + expected = (lr - end_lr) * 0.5**power + end_lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testBeyondEnd(self): - with self.test_session(): - step = 15 - lr = 0.05 - end_lr = 0.001 - power = 0.5 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr, - power=power) - expected = end_lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 15 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = learning_rate_decay.polynomial_decay( + lr, step, 10, end_lr, power=power) + expected = end_lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testBeyondEndWithCycle(self): - with self.test_session(): - step = 15 - lr = 0.05 - end_lr = 0.001 - power = 0.5 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr, - power=power, cycle=True) - expected = (lr - end_lr) * 0.25 ** power + end_lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + step = 15 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = learning_rate_decay.polynomial_decay( + lr, step, 10, end_lr, power=power, cycle=True) + expected = (lr - end_lr) * 0.25**power + end_lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) class PolynomialDecayTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes def testBeginWithCycle(self): - with self.test_session(): - lr = 0.001 - decay_steps = 10 - step = 0 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, - decay_steps, cycle=True) - expected = lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + lr = 0.001 + decay_steps = 10 + step = 0 + decayed_lr = learning_rate_decay.polynomial_decay( + lr, step, decay_steps, cycle=True) + expected = lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) class ExponentialDecayTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes def testDecay(self): initial_lr = 0.1 k = 10 decay_rate = 0.96 - step = gen_state_ops.variable( - shape=[], dtype=dtypes.int32, name="step", container="", shared_name="") - assign_step = state_ops.assign(step, 0) - increment_step = state_ops.assign_add(step, 1) - decayed_lr = learning_rate_decay.natural_exp_decay(initial_lr, step, - k, decay_rate) - with self.test_session(): - assign_step.op.run() - for i in range(k+1): - expected = initial_lr * math.exp(-i / k * decay_rate) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - increment_step.op.run() + step = resource_variable_ops.ResourceVariable(0) + decayed_lr = learning_rate_decay.natural_exp_decay(initial_lr, step, k, + decay_rate) + + self.evaluate(variables.global_variables_initializer()) + for i in range(k + 1): + expected = initial_lr * math.exp(-i / k * decay_rate) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + self.evaluate(step.assign_add(1)) + @test_util.run_in_graph_and_eager_modes def testStaircase(self): initial_lr = 0.1 k = 10 decay_rate = 0.96 - step = gen_state_ops.variable( - shape=[], dtype=dtypes.int32, name="step", container="", shared_name="") - assign_step = state_ops.assign(step, 0) - increment_step = state_ops.assign_add(step, 1) - decayed_lr = learning_rate_decay.natural_exp_decay(initial_lr, - step, - k, - decay_rate, - staircase=True) - with self.test_session(): - assign_step.op.run() - for i in range(k+1): - expected = initial_lr * math.exp(-decay_rate * (i // k)) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - increment_step.op.run() + step = resource_variable_ops.ResourceVariable(0) + decayed_lr = learning_rate_decay.natural_exp_decay( + initial_lr, step, k, decay_rate, staircase=True) + + self.evaluate(variables.global_variables_initializer()) + for i in range(k + 1): + expected = initial_lr * math.exp(-decay_rate * (i // k)) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + self.evaluate(step.assign_add(1)) class InverseDecayTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes def testDecay(self): initial_lr = 0.1 k = 10 decay_rate = 0.96 - step = gen_state_ops.variable( - shape=[], dtype=dtypes.int32, name="step", container="", shared_name="") - assign_step = state_ops.assign(step, 0) - increment_step = state_ops.assign_add(step, 1) - decayed_lr = learning_rate_decay.inverse_time_decay(initial_lr, - step, - k, + step = resource_variable_ops.ResourceVariable(0) + decayed_lr = learning_rate_decay.inverse_time_decay(initial_lr, step, k, decay_rate) - with self.test_session(): - assign_step.op.run() - for i in range(k+1): - expected = initial_lr / (1 + i / k * decay_rate) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - increment_step.op.run() + self.evaluate(variables.global_variables_initializer()) + for i in range(k + 1): + expected = initial_lr / (1 + i / k * decay_rate) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + self.evaluate(step.assign_add(1)) + + @test_util.run_in_graph_and_eager_modes def testStaircase(self): initial_lr = 0.1 k = 10 decay_rate = 0.96 - step = gen_state_ops.variable( - shape=[], dtype=dtypes.int32, name="step", container="", shared_name="") - assign_step = state_ops.assign(step, 0) - increment_step = state_ops.assign_add(step, 1) - decayed_lr = learning_rate_decay.inverse_time_decay(initial_lr, - step, - k, - decay_rate, - staircase=True) - with self.test_session(): - assign_step.op.run() - for i in range(k+1): - expected = initial_lr / (1 + decay_rate * (i // k)) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - increment_step.op.run() + step = resource_variable_ops.ResourceVariable(0) + decayed_lr = learning_rate_decay.inverse_time_decay( + initial_lr, step, k, decay_rate, staircase=True) + + self.evaluate(variables.global_variables_initializer()) + for i in range(k + 1): + expected = initial_lr / (1 + decay_rate * (i // k)) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + self.evaluate(step.assign_add(1)) class CosineDecayTest(test_util.TensorFlowTestCase): @@ -348,34 +339,35 @@ class CosineDecayTest(test_util.TensorFlowTestCase): decay = 0.5 * (1.0 + math.cos(math.pi * completed_fraction)) return (1.0 - alpha) * decay + alpha + @test_util.run_in_graph_and_eager_modes def testDecay(self): num_training_steps = 1000 initial_lr = 1.0 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.cosine_decay( - initial_lr, step, num_training_steps) - expected = self.np_cosine_decay(step, num_training_steps) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.cosine_decay(initial_lr, step, + num_training_steps) + expected = self.np_cosine_decay(step, num_training_steps) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + @test_util.run_in_graph_and_eager_modes def testAlpha(self): num_training_steps = 1000 initial_lr = 1.0 alpha = 0.1 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.cosine_decay( - initial_lr, step, num_training_steps, alpha) - expected = self.np_cosine_decay(step, num_training_steps, alpha) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.cosine_decay(initial_lr, step, + num_training_steps, alpha) + expected = self.np_cosine_decay(step, num_training_steps, alpha) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) class CosineDecayRestartsTest(test_util.TensorFlowTestCase): + def np_cosine_decay_restarts(self, step, decay_steps, t_mul=2.0, m_mul=1.0, alpha=0.0): fac = 1.0 while step >= decay_steps: - step = step - decay_steps + step -= decay_steps decay_steps *= t_mul fac *= m_mul @@ -383,51 +375,51 @@ class CosineDecayRestartsTest(test_util.TensorFlowTestCase): decay = fac * 0.5 * (1.0 + math.cos(math.pi * completed_fraction)) return (1.0 - alpha) * decay + alpha + @test_util.run_in_graph_and_eager_modes def testDecay(self): num_training_steps = 1000 initial_lr = 1.0 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.cosine_decay_restarts( - initial_lr, step, num_training_steps) - expected = self.np_cosine_decay_restarts(step, num_training_steps) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.cosine_decay_restarts( + initial_lr, step, num_training_steps) + expected = self.np_cosine_decay_restarts(step, num_training_steps) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + @test_util.run_in_graph_and_eager_modes def testAlpha(self): num_training_steps = 1000 initial_lr = 1.0 alpha = 0.1 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.cosine_decay_restarts( - initial_lr, step, num_training_steps, alpha=alpha) - expected = self.np_cosine_decay_restarts(step, num_training_steps, - alpha=alpha) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.cosine_decay_restarts( + initial_lr, step, num_training_steps, alpha=alpha) + expected = self.np_cosine_decay_restarts( + step, num_training_steps, alpha=alpha) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + @test_util.run_in_graph_and_eager_modes def testMMul(self): num_training_steps = 1000 initial_lr = 1.0 m_mul = 0.9 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.cosine_decay_restarts( - initial_lr, step, num_training_steps, m_mul=m_mul) - expected = self.np_cosine_decay_restarts(step, num_training_steps, - m_mul=m_mul) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.cosine_decay_restarts( + initial_lr, step, num_training_steps, m_mul=m_mul) + expected = self.np_cosine_decay_restarts( + step, num_training_steps, m_mul=m_mul) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + @test_util.run_in_graph_and_eager_modes def testTMul(self): num_training_steps = 1000 initial_lr = 1.0 t_mul = 1.0 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.cosine_decay_restarts( - initial_lr, step, num_training_steps, t_mul=t_mul) - expected = self.np_cosine_decay_restarts(step, num_training_steps, - t_mul=t_mul) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.cosine_decay_restarts( + initial_lr, step, num_training_steps, t_mul=t_mul) + expected = self.np_cosine_decay_restarts( + step, num_training_steps, t_mul=t_mul) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) class LinearCosineDecayTest(test_util.TensorFlowTestCase): @@ -444,65 +436,63 @@ class LinearCosineDecayTest(test_util.TensorFlowTestCase): cosine_decayed = 0.5 * (1.0 + math.cos(math.pi * fraction)) return (alpha + linear_decayed) * cosine_decayed + beta + @test_util.run_in_graph_and_eager_modes def testDefaultDecay(self): num_training_steps = 1000 initial_lr = 1.0 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.linear_cosine_decay( - initial_lr, step, num_training_steps) - expected = self.np_linear_cosine_decay(step, num_training_steps) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.linear_cosine_decay( + initial_lr, step, num_training_steps) + expected = self.np_linear_cosine_decay(step, num_training_steps) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + @test_util.run_in_graph_and_eager_modes def testNonDefaultDecay(self): num_training_steps = 1000 initial_lr = 1.0 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.linear_cosine_decay( - initial_lr, - step, - num_training_steps, - alpha=0.1, - beta=1e-4, - num_periods=5) - expected = self.np_linear_cosine_decay( - step, - num_training_steps, - alpha=0.1, - beta=1e-4, - num_periods=5) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.linear_cosine_decay( + initial_lr, + step, + num_training_steps, + alpha=0.1, + beta=1e-4, + num_periods=5) + expected = self.np_linear_cosine_decay( + step, num_training_steps, alpha=0.1, beta=1e-4, num_periods=5) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) class NoisyLinearCosineDecayTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes def testDefaultNoisyLinearCosine(self): num_training_steps = 1000 initial_lr = 1.0 for step in range(0, 1500, 250): - with self.test_session(): - # No numerical check because of noise - decayed_lr = learning_rate_decay.noisy_linear_cosine_decay( - initial_lr, step, num_training_steps) - decayed_lr.eval() + # No numerical check because of noise + decayed_lr = learning_rate_decay.noisy_linear_cosine_decay( + initial_lr, step, num_training_steps) + # Cannot be deterministically tested + self.evaluate(decayed_lr) + @test_util.run_in_graph_and_eager_modes def testNonDefaultNoisyLinearCosine(self): num_training_steps = 1000 initial_lr = 1.0 for step in range(0, 1500, 250): - with self.test_session(): - # No numerical check because of noise - decayed_lr = learning_rate_decay.noisy_linear_cosine_decay( - initial_lr, - step, - num_training_steps, - initial_variance=0.5, - variance_decay=0.1, - alpha=0.1, - beta=1e-4, - num_periods=5) - decayed_lr.eval() + # No numerical check because of noise + decayed_lr = learning_rate_decay.noisy_linear_cosine_decay( + initial_lr, + step, + num_training_steps, + initial_variance=0.5, + variance_decay=0.1, + alpha=0.1, + beta=1e-4, + num_periods=5) + # Cannot be deterministically tested + self.evaluate(decayed_lr) if __name__ == "__main__": diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index cae29eea933aec1554beea8d0413fd9febcf2d94..fe9ffde11ce47e1c2ae6c96e59cc2bf0d43d9707 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -730,15 +730,15 @@ class Optimizer( if not named_slots: return None - if hasattr(var, "_mirrored_container"): + if hasattr(var, "_distributed_container"): # NOTE: If this isn't patched, then there is no `handle` in # `_resource_apply_dense`. - mirrored_container = var._mirrored_container() - assert mirrored_container is not None + distributed_container = var._distributed_container() + assert distributed_container is not None if context.executing_eagerly(): - key = mirrored_container._unique_id + key = distributed_container._unique_id else: - key = (mirrored_container.graph, mirrored_container._shared_name) + key = (distributed_container.graph, distributed_container._shared_name) # pylint: enable=protected-access mirrored_slot = named_slots.get(key, None) if mirrored_slot is None: return None @@ -839,7 +839,7 @@ class Optimizer( def _get_non_slot_variable(self, name, graph=None): non_slot = self._non_slot_dict.get((name, graph), None) - if hasattr(non_slot, "_mirrored_container"): + if hasattr(non_slot, "_distributed_container"): # This is a mirrored non-slot. In order to enable code like `_finish` # to assign to a non-slot, return the current context replica. return non_slot.get() diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py index 0cab6410e83ca1880a0a4a80d2cfa5c17517af95..dfe9176beaf27f3cfa945eee8693ba7c5e9551fa 100644 --- a/tensorflow/python/training/optimizer_test.py +++ b/tensorflow/python/training/optimizer_test.py @@ -34,7 +34,7 @@ from tensorflow.python.training import gradient_descent class OptimizerTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBasic(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -112,7 +112,7 @@ class OptimizerTest(test.TestCase): self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)], var1.eval()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoVariables(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: # pylint: disable=cell-var-from-loop @@ -127,7 +127,7 @@ class OptimizerTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'No.*variables'): sgd_op.minimize(loss) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoGradients(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -145,7 +145,7 @@ class OptimizerTest(test.TestCase): # var1 has no gradient sgd_op.minimize(loss, var_list=[var1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoGradientsForAnyVariables_Minimize(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -161,7 +161,7 @@ class OptimizerTest(test.TestCase): 'No gradients provided for any variable'): sgd_op.minimize(loss, var_list=[var0, var1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoGradientsForAnyVariables_ApplyGradients(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -175,7 +175,7 @@ class OptimizerTest(test.TestCase): 'No gradients provided for any variable'): sgd_op.apply_gradients([(None, var0), (None, var1)]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientsAsVariables(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -215,7 +215,7 @@ class OptimizerTest(test.TestCase): self.assertAllClose([-14., -13.], self.evaluate(var0)) self.assertAllClose([-6., -5.], self.evaluate(var1)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testComputeGradientsWithTensors(self): x = ops.convert_to_tensor(1.0) def f(): diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 53ed89e4ab8dca876e232209928e61ba9628eb46..1ee975fbe48e8ba724d8f40040b122c5c02aa352 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -22,7 +22,6 @@ from __future__ import print_function import collections import os.path import re -import sys import time import uuid @@ -1043,8 +1042,8 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None): ckpt = CheckpointState() text_format.Merge(file_content, ckpt) if not ckpt.model_checkpoint_path: - raise ValueError("Invalid checkpoint state loaded from %s", - checkpoint_dir) + raise ValueError("Invalid checkpoint state loaded from " + + checkpoint_dir) # For relative model_checkpoint_path and all_model_checkpoint_paths, # prepend checkpoint_dir. if not os.path.isabs(ckpt.model_checkpoint_path): @@ -1706,12 +1705,17 @@ class Saver(object): save_path: Path where parameters were previously saved. Raises: - ValueError: If save_path is None. + ValueError: If save_path is None or not a valid checkpoint. """ if self._is_empty: return if save_path is None: raise ValueError("Can't load save_path when it is None.") + + if not checkpoint_exists(compat.as_text(save_path)): + raise ValueError("The passed save_path is not a valid checkpoint: " + + compat.as_text(save_path)) + logging.info("Restoring parameters from %s", compat.as_text(save_path)) try: if context.executing_eagerly(): @@ -1719,23 +1723,24 @@ class Saver(object): else: sess.run(self.saver_def.restore_op_name, {self.saver_def.filename_tensor_name: save_path}) - except errors.NotFoundError: - exception_type, exception_value, exception_traceback = sys.exc_info() - # The checkpoint would not be loaded successfully as is. Try to parse it - # as an object-based checkpoint. - should_reraise = False + except errors.NotFoundError as err: + # There are three common conditions that might cause this error: + # 0. The file is missing. We ignore here, as this is checked above. + # 1. This is an object-based checkpoint trying name-based loading. + # 2. The graph has been altered and a variable or other name is missing. + + # 1. The checkpoint would not be loaded successfully as is. Try to parse + # it as an object-based checkpoint. try: reader = pywrap_tensorflow.NewCheckpointReader(save_path) object_graph_string = reader.get_tensor( checkpointable.OBJECT_GRAPH_PROTO_KEY) except errors.NotFoundError: - # This is not an object-based checkpoint, or the checkpoint doesn't - # exist. Re-raise the original exception, but do it outside the except - # block so the object graph lookup isn't included in the stack trace. - should_reraise = True - if should_reraise: - six.reraise(exception_type, exception_value, exception_traceback) - del exception_traceback # avoid reference cycles + # 2. This is not an object-based checkpoint, which likely means there + # is a graph mismatch. Re-raise the original error with + # a helpful message (b/110263146) + raise _wrap_restore_error_with_msg( + err, "a Variable name or other graph key that is missing") # This is an object-based checkpoint. We'll print a warning and then do # the restore. @@ -1747,6 +1752,11 @@ class Saver(object): self._restore_from_object_based_checkpoint( sess=sess, save_path=save_path, object_graph_string=object_graph_string) + except errors.InvalidArgumentError as err: + # There is a mismatch between the graph and the checkpoint being loaded. + # We add a more reasonable error message here to help users (b/110263146) + raise _wrap_restore_error_with_msg( + err, "a mismatch between the current graph and the graph") def _restore_from_object_based_checkpoint(self, sess, save_path, object_graph_string): @@ -2139,6 +2149,14 @@ def _meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"): return meta_graph_filename +def _wrap_restore_error_with_msg(err, extra_verbiage): + err_msg = ("Restoring from checkpoint failed. This is most likely " + "due to {} from the checkpoint. Please ensure that you " + "have not altered the graph expected based on the checkpoint. " + "Original error:\n\n{}").format(extra_verbiage, err.message) + return err.__class__(err.node_def, err.op, err_msg) + + ops.register_proto_function( ops.GraphKeys.SAVERS, proto_type=saver_pb2.SaverDef, diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index e3be7d868e5805cf89c111cf29ea8412521da1be..ae9c244aaf372dcbcf365cf3e6a21ae77d9ae7d0 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -24,10 +24,8 @@ import math import os import random import shutil -import sys import tempfile import time -import traceback import numpy as np import six @@ -79,7 +77,8 @@ from tensorflow.python.training import saver as saver_module from tensorflow.python.training import saver_test_utils from tensorflow.python.training import training_util from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import base as checkpointable_base +from tensorflow.python.training.checkpointable import tracking as checkpointable_tracking from tensorflow.python.training.checkpointable import util as checkpointable_utils from tensorflow.python.util import compat @@ -171,7 +170,7 @@ class SaverTest(test.TestCase): def testBasic(self): self.basicSaveRestore(variables.Variable) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testResourceBasic(self): self.basicSaveRestore(resource_variable_ops.ResourceVariable) @@ -252,7 +251,7 @@ class SaverTest(test.TestCase): self.assertAllEqual(w3.eval(), 3.0) self.assertAllEqual(w4.eval(), 4.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testResourceSaveRestoreCachingDevice(self): save_path = os.path.join(self.get_temp_dir(), "resource_cache") with self.test_session(graph=ops_lib.Graph()) as sess: @@ -368,8 +367,8 @@ class SaverTest(test.TestCase): for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2): with self.test_session() as sess: save = saver_module.Saver({"v0": v0}, write_version=ver) - with self.assertRaisesRegexp(errors.NotFoundError, - "Failed to find any matching files for"): + with self.assertRaisesRegexp( + ValueError, "The passed save_path is not a valid checkpoint:"): save.restore(sess, "invalid path") def testInt64(self): @@ -671,7 +670,7 @@ class SaverTest(test.TestCase): save.restore(sess, save_path) self.assertAllClose([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], var.eval()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSaveWithGlobalStep(self, pad_step_number=False): save_path = os.path.join(self.get_temp_dir(), "ckpt_with_global_step") global_step_int = 5 @@ -1395,7 +1394,7 @@ class KeepCheckpointEveryNHoursTest(test.TestCase): gfile.MakeDirs(test_dir) return test_dir - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes @test.mock.patch.object(saver_module, "time") def testNonSharded(self, mock_time): save_dir = self._get_test_dir("keep_checkpoint_every_n_hours") @@ -1515,7 +1514,7 @@ class SaveRestoreWithVariableNameMap(test.TestCase): self.assertEqual(10.0, self.evaluate(v0)) self.assertEqual(20.0, self.evaluate(v1)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNonReshapeResourceVariable(self): self._testNonReshape(resource_variable_ops.ResourceVariable) @@ -2939,7 +2938,7 @@ class ScopedGraphTest(test.TestCase): self.assertEqual(2.0, var_dict2["variable2:0"].eval()) -class _OwnsAVariableSimple(checkpointable.CheckpointableBase): +class _OwnsAVariableSimple(checkpointable_base.CheckpointableBase): """A Checkpointable object which can be saved using a tf.train.Saver.""" def __init__(self): @@ -2947,7 +2946,7 @@ class _OwnsAVariableSimple(checkpointable.CheckpointableBase): name="non_dep_variable", initializer=6., use_resource=True) def _gather_saveables_for_checkpoint(self): - return {checkpointable.VARIABLE_VALUE_KEY: self.non_dep_variable} + return {checkpointable_base.VARIABLE_VALUE_KEY: self.non_dep_variable} # The Saver sorts by name before parsing, so we need a name property. @property @@ -2972,7 +2971,7 @@ class _MirroringSaveable( self._mirrored_variable.assign(tensor)) -class _OwnsMirroredVariables(checkpointable.CheckpointableBase): +class _OwnsMirroredVariables(checkpointable_base.CheckpointableBase): """A Checkpointable object which returns a more complex SaveableObject.""" def __init__(self): @@ -2987,7 +2986,7 @@ class _OwnsMirroredVariables(checkpointable.CheckpointableBase): primary_variable=self.non_dep_variable, mirrored_variable=self.mirrored, name=name) - return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} + return {checkpointable_base.VARIABLE_VALUE_KEY: _saveable_factory} # The Saver sorts by name before parsing, so we need a name property. @property @@ -2995,7 +2994,7 @@ class _OwnsMirroredVariables(checkpointable.CheckpointableBase): return self.non_dep_variable.name -class NonLayerCheckpointable(checkpointable.Checkpointable): +class NonLayerCheckpointable(checkpointable_tracking.Checkpointable): def __init__(self): super(NonLayerCheckpointable, self).__init__() @@ -3021,7 +3020,7 @@ class MyModel(training.Model): class CheckpointableCompatibilityTests(test.TestCase): # TODO(allenl): Track down python3 reference cycles in these tests. - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNotSaveableButIsCheckpointable(self): v = _OwnsAVariableSimple() saver = saver_module.Saver(var_list=[v]) @@ -3034,7 +3033,7 @@ class CheckpointableCompatibilityTests(test.TestCase): saver.restore(sess, save_path) self.assertEqual(42., self.evaluate(v.non_dep_variable)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMoreComplexSaveableReturned(self): v = _OwnsMirroredVariables() saver = saver_module.Saver(var_list=[v]) @@ -3138,27 +3137,33 @@ class CheckpointableCompatibilityTests(test.TestCase): errors.NotFoundError, "Key b not found in checkpoint"): b_saver.restore(sess=sess, save_path=save_path) - def testCheckpointNotFoundErrorRaised(self): - # Restore does some tricky exception handling to figure out if it should - # load an object-based checkpoint. Tests that the exception handling isn't - # too broad. - a = resource_variable_ops.ResourceVariable(1., name="a") - saver = saver_module.Saver([a]) - with self.test_session() as sess: - with self.assertRaisesRegexp( - errors.NotFoundError, - "Failed to find any matching files for path_which_does_not_exist"): - saver.restore(sess=sess, save_path="path_which_does_not_exist") - try: - saver.restore(sess=sess, save_path="path_which_does_not_exist") - except errors.NotFoundError: - # Make sure we don't have a confusing "During handling of the above - # exception" block in Python 3. - # pylint: disable=no-value-for-parameter - exception_string = "\n".join( - traceback.format_exception(*sys.exc_info())) - # pylint: enable=no-value-for-parameter - self.assertNotIn("NewCheckpointReader", exception_string) + with self.assertRaises(errors.NotFoundError) as cs: + b_saver.restore(sess=sess, save_path=save_path) + + # Make sure we don't have a confusing "During handling of the above + # exception" block in Python 3. + self.assertNotIn("NewCheckpointReader", cs.exception.message) + + def testGraphChangedForRestoreErrorRaised(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + + with ops_lib.Graph().as_default() as g: + a = variables.Variable(1., name="a") + a_saver = saver_module.Saver([a]) + + with self.test_session(graph=g) as sess: + sess.run(a.initializer) + save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix) + + with ops_lib.Graph().as_default() as g: + a = variables.Variable([1.], name="a") + a_saver = saver_module.Saver([a]) + with self.test_session(graph=g) as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "a mismatch between the current graph and the graph"): + a_saver.restore(sess=sess, save_path=save_path) def testLoadFromObjectBasedGraph(self): checkpoint_directory = self.get_temp_dir() diff --git a/tensorflow/python/util/lock_util.py b/tensorflow/python/util/lock_util.py new file mode 100644 index 0000000000000000000000000000000000000000..0424960666323870fb1db83804857dd838cfe9ae --- /dev/null +++ b/tensorflow/python/util/lock_util.py @@ -0,0 +1,128 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Locking related utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + + +class GroupLock(object): + """A lock to allow many members of a group to access a resource exclusively. + + This lock provides a way to allow access to a resource by multiple threads + belonging to a logical group at the same time, while restricting access to + threads from all other groups. You can think of this as an extension of a + reader-writer lock, where you allow multiple writers at the same time. We + made it generic to support multiple groups instead of just two - readers and + writers. + + Simple usage example with two groups accessing the same resource: + + ```python + lock = GroupLock(num_groups=2) + + # In a member of group 0: + with lock.group(0): + # do stuff, access the resource + # ... + + # In a member of group 1: + with lock.group(1): + # do stuff, access the resource + # ... + ``` + + Using as a context manager with `.group(group_id)` is the easiest way. You + can also use the `acquire` and `release` method directly. + """ + + def __init__(self, num_groups=2): + """Initialize a group lock. + + Args: + num_groups: The number of groups that will be accessing the resource under + consideration. Should be a positive number. + + Returns: + A group lock that can then be used to synchronize code. + + Raises: + ValueError: If num_groups is less than 1. + """ + if num_groups < 1: + raise ValueError("num_groups must be a positive integer, got {}".format( + num_groups)) + self._ready = threading.Condition(threading.Lock()) + self._num_groups = num_groups + self._group_member_counts = [0] * self._num_groups + + def group(self, group_id): + """Enter a context where the lock is with group `group_id`. + + Args: + group_id: The group for which to acquire and release the lock. + + Returns: + A context manager which will acquire the lock for `group_id`. + """ + self._validate_group_id(group_id) + return self._Context(self, group_id) + + def acquire(self, group_id): + """Acquire the group lock for a specific group `group_id`.""" + self._validate_group_id(group_id) + + self._ready.acquire() + while self._another_group_active(group_id): + self._ready.wait() + self._group_member_counts[group_id] += 1 + self._ready.release() + + def release(self, group_id): + """Release the group lock for a specific group `group_id`.""" + self._validate_group_id(group_id) + + self._ready.acquire() + self._group_member_counts[group_id] -= 1 + if self._group_member_counts[group_id] == 0: + self._ready.notifyAll() + self._ready.release() + + def _another_group_active(self, group_id): + return any( + c > 0 for g, c in enumerate(self._group_member_counts) if g != group_id) + + def _validate_group_id(self, group_id): + if group_id < 0 or group_id >= self._num_groups: + raise ValueError( + "group_id={} should be between 0 and num_groups={}".format( + group_id, self._num_groups)) + + class _Context(object): + """Context manager helper for `GroupLock`.""" + + def __init__(self, lock, group_id): + self._lock = lock + self._group_id = group_id + + def __enter__(self): + self._lock.acquire(self._group_id) + + def __exit__(self, type_arg, value_arg, traceback_arg): + del type_arg, value_arg, traceback_arg + self._lock.release(self._group_id) diff --git a/tensorflow/python/util/lock_util_test.py b/tensorflow/python/util/lock_util_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cda8f952259c9e117e0bd7ff3cac35e764856f43 --- /dev/null +++ b/tensorflow/python/util/lock_util_test.py @@ -0,0 +1,63 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for lock_util.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +import time + +from absl.testing import parameterized + +from tensorflow.python.platform import test +from tensorflow.python.util import lock_util + + +class GroupLockTest(test.TestCase, parameterized.TestCase): + + @parameterized.parameters(1, 2, 3, 5, 10) + def testGroups(self, num_groups): + lock = lock_util.GroupLock(num_groups) + num_threads = 10 + finished = set() + + def thread_fn(thread_id): + time.sleep(random.random() * 0.1) + group_id = thread_id % num_groups + with lock.group(group_id): + time.sleep(random.random() * 0.1) + self.assertGreater(lock._group_member_counts[group_id], 0) + for g, c in enumerate(lock._group_member_counts): + if g != group_id: + self.assertEqual(0, c) + finished.add(thread_id) + + threads = [ + self.checkedThread(target=thread_fn, args=(i,)) + for i in range(num_threads) + ] + + for i in range(num_threads): + threads[i].start() + for i in range(num_threads): + threads[i].join() + + self.assertEqual(set(range(num_threads)), finished) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/util/serialization_test.py b/tensorflow/python/util/serialization_test.py index 5000bcfad05900e63bc72c1bd0e31e30434b74ae..9d9cac272592f6b73b4c78f38310d7b89a89e05d 100644 --- a/tensorflow/python/util/serialization_test.py +++ b/tensorflow/python/util/serialization_test.py @@ -47,7 +47,7 @@ class SerializationTests(test.TestCase): self.assertIs(round_trip[0], None) self.assertEqual(round_trip[1], 2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_serialize_sequential(self): model = sequential.Sequential() model.add(core.Dense(4)) @@ -61,7 +61,7 @@ class SerializationTests(test.TestCase): self.assertAllEqual([1, 1], input_round_trip[0]["config"]["batch_input_shape"]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_serialize_model(self): x = input_layer.Input(shape=[3]) y = core.Dense(10)(x) diff --git a/tensorflow/stream_executor/BUILD b/tensorflow/stream_executor/BUILD index c68cda01002b1c5bbc2facb95b1eba214fbad7cb..e742f8e8d51d0217b631ebdc23ee65263c1ce0f0 100644 --- a/tensorflow/stream_executor/BUILD +++ b/tensorflow/stream_executor/BUILD @@ -2,6 +2,7 @@ licenses(["restricted"]) load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") +load("//tensorflow:tensorflow.bzl", "cc_header_only_library") STREAM_EXECUTOR_HEADERS = glob([ "*.h", @@ -33,7 +34,6 @@ cc_library( }), visibility = ["//visibility:public"], deps = [ - "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "@local_config_cuda//cuda:cuda_headers", @@ -48,11 +48,18 @@ cc_library( deps = [ "//tensorflow/core:lib", "//tensorflow/core:ptr_util", - "//tensorflow/compiler/xla:statusor", "@local_config_cuda//cuda:cuda_headers", ] + if_static([":stream_executor_impl"]), ) +cc_header_only_library( + name = "stream_executor_headers_lib", + visibility = ["//visibility:public"], + deps = [ + ":stream_executor", + ], +) + cc_library( name = "cuda_platform", srcs = if_cuda_is_configured( diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index 31e407f199844ff51ace55469e2abd50d2cefb07..874bf0e8cb481bf9e506e6d9b71c19afbe89d644 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -2183,8 +2183,8 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( // Return false if we might be hitting a cuBLAS bug that produces the wrong // result. See nvbugs/2156201, b/79126339. -#if (CUDA_VERSION >= 9000) - if (CUDA_VERSION < 9020 && algorithm != CUBLAS_GEMM_ALGO12 && +#if CUDA_VERSION >= 9000 && CUDA_VERSION < 9020 + if ((algorithm == CUBLAS_GEMM_DEFAULT || algorithm >= CUBLAS_GEMM_ALGO13) && std::max({m, n, k}) >= 2097153 && cc_major < 7) { return false; } diff --git a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc index 10f6d21d54a15b46ebfc6f8ad32e3e908fab9a96..124d5905b91cbf839437e763728cc76ad0d671dc 100644 --- a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc +++ b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc @@ -24,12 +24,17 @@ limitations under the License. #include #include #include +#ifdef __APPLE__ +#include +#include +#else #if !defined(PLATFORM_WINDOWS) #include #include #include #endif #include +#endif #include #include #include @@ -49,7 +54,9 @@ limitations under the License. namespace stream_executor { namespace cuda { -#if !defined(PLATFORM_WINDOWS) +#ifdef __APPLE__ +static const CFStringRef kDriverKextIdentifier = CFSTR("com.nvidia.CUDA"); +#elif !defined(PLATFORM_WINDOWS) static const char *kDriverVersionPath = "/proc/driver/nvidia/version"; #endif @@ -114,7 +121,31 @@ string Diagnostician::GetDevNodePath(int dev_node_ordinal) { } void Diagnostician::LogDiagnosticInformation() { -#if !defined(PLATFORM_WINDOWS) +#ifdef __APPLE__ + CFStringRef kext_ids[1]; + kext_ids[0] = kDriverKextIdentifier; + CFArrayRef kext_id_query = CFArrayCreate(nullptr, (const void **)kext_ids, 1, + &kCFTypeArrayCallBacks); + CFDictionaryRef kext_infos = + KextManagerCopyLoadedKextInfo(kext_id_query, nullptr); + CFRelease(kext_id_query); + + CFDictionaryRef cuda_driver_info = nullptr; + if (CFDictionaryGetValueIfPresent(kext_infos, kDriverKextIdentifier, + (const void **)&cuda_driver_info)) { + bool started = CFBooleanGetValue((CFBooleanRef)CFDictionaryGetValue( + cuda_driver_info, CFSTR("OSBundleStarted"))); + if (!started) { + LOG(INFO) << "kernel driver is installed, but does not appear to be " + "running on this host " + << "(" << port::Hostname() << ")"; + } + } else { + LOG(INFO) << "kernel driver does not appear to be installed on this host " + << "(" << port::Hostname() << ")"; + } + CFRelease(kext_infos); +#elif !defined(PLATFORM_WINDOWS) if (access(kDriverVersionPath, F_OK) != 0) { LOG(INFO) << "kernel driver does not appear to be running on this host " << "(" << port::Hostname() << "): " @@ -168,7 +199,8 @@ void Diagnostician::LogDiagnosticInformation() { << DriverVersionStatusToString(kernel_version); #endif -#if !defined(PLATFORM_WINDOWS) + // OS X kernel driver does not report version accurately +#if !defined(__APPLE__) && !defined(PLATFORM_WINDOWS) if (kernel_version.ok() && dso_version.ok()) { WarnOnDsoKernelMismatch(dso_version, kernel_version); } @@ -182,6 +214,29 @@ port::StatusOr Diagnostician::FindDsoVersion() { port::error::NOT_FOUND, "was unable to find libcuda.so DSO loaded into this program")); +#if defined(__APPLE__) + // OSX CUDA libraries have names like: libcuda_310.41.15_mercury.dylib + const string prefix("libcuda_"); + const string suffix("_mercury.dylib"); + for (uint32_t image_index = 0; image_index < _dyld_image_count(); + ++image_index) { + const string path(_dyld_get_image_name(image_index)); + const size_t suffix_pos = path.rfind(suffix); + const size_t prefix_pos = path.rfind(prefix, suffix_pos); + if (prefix_pos == string::npos || suffix_pos == string::npos) { + // no match + continue; + } + const size_t start = prefix_pos + prefix.size(); + if (start >= suffix_pos) { + // version not included + continue; + } + const size_t length = suffix_pos - start; + const string version = path.substr(start, length); + result = StringToDriverVersion(version); + } +#else #if !defined(PLATFORM_WINDOWS) && !defined(ANDROID_TEGRA) // Callback used when iterating through DSOs. Looks for the driver-interfacing // DSO and yields its version number into the callback data, when found. @@ -214,6 +269,7 @@ port::StatusOr Diagnostician::FindDsoVersion() { }; dl_iterate_phdr(iterate_phdr, &result); +#endif #endif return result; @@ -259,7 +315,41 @@ void Diagnostician::WarnOnDsoKernelMismatch( port::StatusOr Diagnostician::FindKernelDriverVersion() { -#if defined(PLATFORM_WINDOWS) +#if defined(__APPLE__) + CFStringRef kext_ids[1]; + kext_ids[0] = kDriverKextIdentifier; + CFArrayRef kext_id_query = CFArrayCreate(nullptr, (const void **)kext_ids, 1, + &kCFTypeArrayCallBacks); + CFDictionaryRef kext_infos = + KextManagerCopyLoadedKextInfo(kext_id_query, nullptr); + CFRelease(kext_id_query); + + CFDictionaryRef cuda_driver_info = nullptr; + if (CFDictionaryGetValueIfPresent(kext_infos, kDriverKextIdentifier, + (const void **)&cuda_driver_info)) { + // NOTE: OSX CUDA driver does not currently store the same driver version + // in kCFBundleVersionKey as is returned by cuDriverGetVersion + CFRelease(kext_infos); + const CFStringRef str = (CFStringRef)CFDictionaryGetValue( + cuda_driver_info, kCFBundleVersionKey); + const char *version = CFStringGetCStringPtr(str, kCFStringEncodingUTF8); + + // version can be NULL in which case treat it as empty string + // see + // https://developer.apple.com/library/mac/documentation/CoreFoundation/Conceptual/CFStrings/Articles/AccessingContents.html#//apple_ref/doc/uid/20001184-100980-TPXREF112 + if (version == NULL) { + return StringToDriverVersion(""); + } + return StringToDriverVersion(version); + } + CFRelease(kext_infos); + auto status = port::Status( + port::error::INTERNAL, + port::StrCat( + "failed to read driver bundle version: ", + CFStringGetCStringPtr(kDriverKextIdentifier, kCFStringEncodingUTF8))); + return status; +#elif defined(PLATFORM_WINDOWS) auto status = port::Status(port::error::UNIMPLEMENTED, "kernel reported driver version not implemented on Windows"); diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index edf217875ff30b26407e35a39d1ee603b980b480..f11022ef1dfd4a1a08d035f5328724d93ac808be 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h" +#if defined(__APPLE__) +#include +#endif #if defined(PLATFORM_WINDOWS) #include #define PATH_MAX MAX_PATH @@ -176,11 +179,19 @@ bool CUDAExecutor::FindOnDiskForComputeCapability( // would return /usr/bin. static string GetBinaryDir(bool strip_exe) { char exe_path[PATH_MAX] = {0}; +#if defined(__APPLE__) + uint32_t buffer_size = 0U; + _NSGetExecutablePath(nullptr, &buffer_size); + char unresolved_path[buffer_size]; + _NSGetExecutablePath(unresolved_path, &buffer_size); + CHECK_ERR(realpath(unresolved_path, exe_path) ? 1 : -1); +#else #if defined(PLATFORM_WINDOWS) HMODULE hModule = GetModuleHandle(NULL); GetModuleFileName(hModule, exe_path, MAX_PATH); #else CHECK_ERR(readlink("/proc/self/exe", exe_path, sizeof(exe_path) - 1)); +#endif #endif // Make sure it's null-terminated: exe_path[sizeof(exe_path) - 1] = 0; @@ -843,7 +854,10 @@ CudaContext* CUDAExecutor::cuda_context() { return context_; } // For anything more complicated/prod-focused than this, you'll likely want to // turn to gsys' topology modeling. static int TryToReadNumaNode(const string &pci_bus_id, int device_ordinal) { -#if defined(PLATFORM_WINDOWS) +#if defined(__APPLE__) + LOG(INFO) << "OS X does not support NUMA - returning NUMA node zero"; + return 0; +#elif defined(PLATFORM_WINDOWS) // Windows support for NUMA is not currently implemented. Return node 0. return 0; #elif defined(__aarch64__) diff --git a/tensorflow/stream_executor/host/host_gpu_executor.cc b/tensorflow/stream_executor/host/host_gpu_executor.cc index 2c4819651acaa2c6ee99c720b2c3d80e5c2ea1a9..c8a629733006e17b7642a59afb8e0cb468f2c538 100644 --- a/tensorflow/stream_executor/host/host_gpu_executor.cc +++ b/tensorflow/stream_executor/host/host_gpu_executor.cc @@ -95,7 +95,7 @@ bool HostExecutor::MemcpyDeviceToDevice(Stream *stream, // the nature of the HostExecutor) memcpy on the stream (HostStream) // associated with the HostExecutor. AsHostStream(stream)->EnqueueTask( - [src_mem, dst_mem, size]() { memcpy(src_mem, dst_mem, size); }); + [src_mem, dst_mem, size]() { memcpy(dst_mem, src_mem, size); }); return true; } diff --git a/tensorflow/compiler/xla/statusor.cc b/tensorflow/stream_executor/lib/statusor.cc similarity index 89% rename from tensorflow/compiler/xla/statusor.cc rename to tensorflow/stream_executor/lib/statusor.cc index 72ab67ff810e0ec384a22da092363cc7446435bb..e0e851f96ef6fe18ec32ff7d3fd1d1aed18b0343 100644 --- a/tensorflow/compiler/xla/statusor.cc +++ b/tensorflow/stream_executor/lib/statusor.cc @@ -13,12 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" -namespace xla { +namespace stream_executor { +namespace port { namespace internal_statusor { void Helper::HandleInvalidStatusCtorArg(Status* status) { @@ -35,4 +36,5 @@ void Helper::Crash(const Status& status) { } } // namespace internal_statusor -} // namespace xla +} // namespace port +} // namespace stream_executor diff --git a/tensorflow/stream_executor/lib/statusor.h b/tensorflow/stream_executor/lib/statusor.h index dab59096740102b94c0ff63c089b83ce052ea264..3c716acb462f1ca25e1d86408386d9eca37265b7 100644 --- a/tensorflow/stream_executor/lib/statusor.h +++ b/tensorflow/stream_executor/lib/statusor.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,19 +13,297 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// IWYU pragma: private, include "third_party/tensorflow/stream_executor/stream_executor.h" - +// StatusOr is the union of a Status object and a T object. StatusOr models +// the concept of an object that is either a value, or an error Status +// explaining why such a value is not present. To this end, StatusOr does not +// allow its Status value to be Status::OK. +// +// The primary use-case for StatusOr is as the return value of a +// function which may fail. +// +// Example client usage for a StatusOr, where T is not a pointer: +// +// StatusOr result = DoBigCalculationThatCouldFail(); +// if (result.ok()) { +// float answer = result.ValueOrDie(); +// printf("Big calculation yielded: %f", answer); +// } else { +// LOG(ERROR) << result.status(); +// } +// +// Example client usage for a StatusOr: +// +// StatusOr result = FooFactory::MakeNewFoo(arg); +// if (result.ok()) { +// std::unique_ptr foo(result.ValueOrDie()); +// foo->DoSomethingCool(); +// } else { +// LOG(ERROR) << result.status(); +// } +// +// Example client usage for a StatusOr>: +// +// StatusOr> result = FooFactory::MakeNewFoo(arg); +// if (result.ok()) { +// std::unique_ptr foo = std::move(result.ValueOrDie()); +// foo->DoSomethingCool(); +// } else { +// LOG(ERROR) << result.status(); +// } +// +// Example factory implementation returning StatusOr: +// +// StatusOr FooFactory::MakeNewFoo(int arg) { +// if (arg <= 0) { +// return tensorflow::InvalidArgument("Arg must be positive"); +// } else { +// return new Foo(arg); +// } +// } +// +// Note that the assignment operators require that destroying the currently +// stored value cannot invalidate the argument; in other words, the argument +// cannot be an alias for the current value, or anything owned by the current +// value. #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_H_ #define TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_H_ -#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/lib/statusor_internals.h" namespace stream_executor { namespace port { -// Use XLA's StatusOr so we don't duplicate code. +#if defined(__clang__) +// Only clang supports warn_unused_result as a type annotation. +template +class TF_MUST_USE_RESULT StatusOr; +#endif + +template +class StatusOr : private internal_statusor::StatusOrData, + private internal_statusor::TraitsBase< + std::is_copy_constructible::value, + std::is_move_constructible::value> { + template + friend class StatusOr; + + typedef internal_statusor::StatusOrData Base; + + public: + typedef T element_type; + + // Constructs a new StatusOr with Status::UNKNOWN status. This is marked + // 'explicit' to try to catch cases like 'return {};', where people think + // StatusOr> will be initialized with an empty vector, + // instead of a Status::UNKNOWN status. + explicit StatusOr(); + + // StatusOr will be copy constructible/assignable if T is copy + // constructible. + StatusOr(const StatusOr&) = default; + StatusOr& operator=(const StatusOr&) = default; + + // StatusOr will be move constructible/assignable if T is move + // constructible. + StatusOr(StatusOr&&) = default; + StatusOr& operator=(StatusOr&&) = default; + + // Conversion copy/move constructor, T must be convertible from U. + template ::value>::type* = nullptr> + StatusOr(const StatusOr& other); + template ::value>::type* = nullptr> + StatusOr(StatusOr&& other); + + // Conversion copy/move assignment operator, T must be convertible from U. + template ::value>::type* = nullptr> + StatusOr& operator=(const StatusOr& other); + template ::value>::type* = nullptr> + StatusOr& operator=(StatusOr&& other); + + // Constructs a new StatusOr with the given value. After calling this + // constructor, calls to ValueOrDie() will succeed, and calls to status() will + // return OK. + // + // NOTE: Not explicit - we want to use StatusOr as a return type + // so it is convenient and sensible to be able to do 'return T()' + // when the return type is StatusOr. + // + // REQUIRES: T is copy constructible. + StatusOr(const T& value); + + // Constructs a new StatusOr with the given non-ok status. After calling + // this constructor, calls to ValueOrDie() will CHECK-fail. + // + // NOTE: Not explicit - we want to use StatusOr as a return + // value, so it is convenient and sensible to be able to do 'return + // Status()' when the return type is StatusOr. + // + // REQUIRES: !status.ok(). This requirement is DCHECKed. + // In optimized builds, passing Status::OK() here will have the effect + // of passing tensorflow::error::INTERNAL as a fallback. + StatusOr(const Status& status); + StatusOr& operator=(const Status& status); + + // TODO(b/62186997): Add operator=(T) overloads. + + // Similar to the `const T&` overload. + // + // REQUIRES: T is move constructible. + StatusOr(T&& value); + + // RValue versions of the operations declared above. + StatusOr(Status&& status); + StatusOr& operator=(Status&& status); + + // Returns this->status().ok() + bool ok() const { return this->status_.ok(); } + + // Returns a reference to our status. If this contains a T, then + // returns Status::OK(). + const Status& status() const &; + Status status() &&; + + // Returns a reference to our current value, or CHECK-fails if !this->ok(). + // + // Note: for value types that are cheap to copy, prefer simple code: + // + // T value = statusor.ValueOrDie(); + // + // Otherwise, if the value type is expensive to copy, but can be left + // in the StatusOr, simply assign to a reference: + // + // T& value = statusor.ValueOrDie(); // or `const T&` + // + // Otherwise, if the value type supports an efficient move, it can be + // used as follows: + // + // T value = std::move(statusor).ValueOrDie(); + // + // The std::move on statusor instead of on the whole expression enables + // warnings about possible uses of the statusor object after the move. + // C++ style guide waiver for ref-qualified overloads granted in cl/143176389 + // See go/ref-qualifiers for more details on such overloads. + const T& ValueOrDie() const &; + T& ValueOrDie() &; + const T&& ValueOrDie() const &&; + T&& ValueOrDie() &&; + + T ConsumeValueOrDie() { return std::move(ValueOrDie()); } + + // Ignores any errors. This method does nothing except potentially suppress + // complaints from any tools that are checking that errors are not dropped on + // the floor. + void IgnoreError() const; +}; + +//////////////////////////////////////////////////////////////////////////////// +// Implementation details for StatusOr + +template +StatusOr::StatusOr() : Base(Status(tensorflow::error::UNKNOWN, "")) {} + +template +StatusOr::StatusOr(const T& value) : Base(value) {} + +template +StatusOr::StatusOr(const Status& status) : Base(status) {} + +template +StatusOr& StatusOr::operator=(const Status& status) { + this->Assign(status); + return *this; +} + +template +StatusOr::StatusOr(T&& value) : Base(std::move(value)) {} + +template +StatusOr::StatusOr(Status&& status) : Base(std::move(status)) {} + +template +StatusOr& StatusOr::operator=(Status&& status) { + this->Assign(std::move(status)); + return *this; +} + +template +template ::value>::type*> +inline StatusOr::StatusOr(const StatusOr& other) + : Base(static_cast::Base&>(other)) {} + +template +template ::value>::type*> +inline StatusOr& StatusOr::operator=(const StatusOr& other) { + if (other.ok()) + this->Assign(other.ValueOrDie()); + else + this->Assign(other.status()); + return *this; +} + +template +template ::value>::type*> +inline StatusOr::StatusOr(StatusOr&& other) + : Base(static_cast::Base&&>(other)) {} + +template +template ::value>::type*> +inline StatusOr& StatusOr::operator=(StatusOr&& other) { + if (other.ok()) { + this->Assign(std::move(other).ValueOrDie()); + } else { + this->Assign(std::move(other).status()); + } + return *this; +} + +template +const Status& StatusOr::status() const & { + return this->status_; +} +template +Status StatusOr::status() && { + return ok() ? Status::OK() : std::move(this->status_); +} + +template +const T& StatusOr::ValueOrDie() const & { + this->EnsureOk(); + return this->data_; +} + +template +T& StatusOr::ValueOrDie() & { + this->EnsureOk(); + return this->data_; +} + +template +const T&& StatusOr::ValueOrDie() const && { + this->EnsureOk(); + return std::move(this->data_); +} + +template +T&& StatusOr::ValueOrDie() && { + this->EnsureOk(); + return std::move(this->data_); +} + template -using StatusOr = ::xla::StatusOr; +void StatusOr::IgnoreError() const { + // no-op +} } // namespace port } // namespace stream_executor diff --git a/tensorflow/compiler/xla/statusor_internals.h b/tensorflow/stream_executor/lib/statusor_internals.h similarity index 94% rename from tensorflow/compiler/xla/statusor_internals.h rename to tensorflow/stream_executor/lib/statusor_internals.h index 14636bd144bc0a155fc96c5a350c658fd2dadfe6..09f88f5825f57c8e654bd079616a074e84de4f30 100644 --- a/tensorflow/compiler/xla/statusor_internals.h +++ b/tensorflow/stream_executor/lib/statusor_internals.h @@ -13,13 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ -#define TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ +#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_INTERNALS_H_ +#define TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_INTERNALS_H_ + -#include "tensorflow/compiler/xla/status.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/stream_executor/lib/status.h" -namespace xla { +namespace stream_executor { +namespace port { namespace internal_statusor { class Helper { @@ -240,6 +242,7 @@ struct TraitsBase { }; } // namespace internal_statusor -} // namespace xla +} // namespace port +} // namespace stream_executor -#endif // TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ +#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_INTERNALS_H_ diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/stream_executor/lib/statusor_test.cc similarity index 99% rename from tensorflow/compiler/xla/statusor_test.cc rename to tensorflow/stream_executor/lib/statusor_test.cc index 377a618ffbd99316d409130df8a39f352664dee0..56584e189208b2576f10650fd56bca6d04ecc6c1 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/stream_executor/lib/statusor_test.cc @@ -15,18 +15,18 @@ limitations under the License. // Unit tests for StatusOr -#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/stream_executor/lib/statusor.h" #include #include -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/test.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test_benchmark.h" -namespace xla { +namespace stream_executor { +namespace port { namespace { class Base1 { @@ -672,4 +672,5 @@ void BM_StatusOrFactoryFailLongMsg(int iters) { BENCHMARK(BM_StatusOrFactoryFailLongMsg); } // namespace -} // namespace xla +} // namespace port +} // namespace stream_executor diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index 3da1b856d6a41fa0c8d5a77feac33932da392422..e8885e1eb682d9ee67c6b7594f96c0911c7c1fa2 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -25,6 +25,7 @@ limitations under the License. #include #include +#include "tensorflow/core/platform/macros.h" #include "tensorflow/stream_executor/blas.h" #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/dnn.h" @@ -156,14 +157,13 @@ class Stream { const TypedKernel &kernel, Args... args); // Record a "start" event for the interval timer at this point in the - // stream's - // execution (relative to the previously and subsequently enqueued items in - // the stream's execution). Streams may be started/stopped multiple times. + // stream's execution (relative to the previously and subsequently enqueued + // items in the stream's execution). Streams may be started/stopped multiple + // times. Stream &ThenStartTimer(Timer *t); // Record a "stop" event for the interval timer at this point in the - // stream's - // execution. See also Stream::ThenStartTimer. + // stream's execution. See also Stream::ThenStartTimer. Stream &ThenStopTimer(Timer *t); // TODO(leary) If work is added to the stream that is being depended upon, @@ -179,8 +179,7 @@ class Stream { // // Checks that a stream does not wait for itself, and it is up to the // user to guarantee that a stream does not come to wait on itself in a - // cyclic - // manner; in that case, behavior is undefined. + // cyclic manner; in that case, behavior is undefined. // // N.B. Base recursion case for the variadic ThenWaitFor. Stream &ThenWaitFor(Stream *other); @@ -1351,33 +1350,39 @@ class Stream { DeviceMemory> *x, int incx); // See BlasSupport::DoBlasGemm. - Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m, - uint64 n, uint64 k, float alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, float beta, - DeviceMemory *c, int ldc); - Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m, - uint64 n, uint64 k, float alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, float beta, - DeviceMemory *c, int ldc); - Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m, - uint64 n, uint64 k, double alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, double beta, - DeviceMemory *c, int ldc); - Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m, - uint64 n, uint64 k, std::complex alpha, - const DeviceMemory> &a, int lda, - const DeviceMemory> &b, int ldb, - std::complex beta, - DeviceMemory> *c, int ldc); - Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m, - uint64 n, uint64 k, std::complex alpha, - const DeviceMemory> &a, int lda, - const DeviceMemory> &b, int ldb, - std::complex beta, - DeviceMemory> *c, int ldc); + TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, + uint64 m, uint64 n, uint64 k, float alpha, + const DeviceMemory &a, int lda, + const DeviceMemory &b, int ldb, + float beta, DeviceMemory *c, + int ldc); + TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, + uint64 m, uint64 n, uint64 k, float alpha, + const DeviceMemory &a, int lda, + const DeviceMemory &b, int ldb, + float beta, DeviceMemory *c, int ldc); + TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, + uint64 m, uint64 n, uint64 k, double alpha, + const DeviceMemory &a, int lda, + const DeviceMemory &b, int ldb, + double beta, DeviceMemory *c, int ldc); + TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, + uint64 m, uint64 n, uint64 k, + std::complex alpha, + const DeviceMemory> &a, + int lda, + const DeviceMemory> &b, + int ldb, std::complex beta, + DeviceMemory> *c, int ldc); + TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, + uint64 m, uint64 n, uint64 k, + std::complex alpha, + const DeviceMemory> &a, + int lda, + const DeviceMemory> &b, + int ldb, std::complex beta, + DeviceMemory> *c, + int ldc); Stream &ThenBlasGemmWithProfiling(blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index b59f8e1f987567727ef3d4051618edd377d06f89..e4632c48112d40fb96b4c2b510da93678b11efc4 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -148,6 +148,12 @@ def if_windows(a): "//conditions:default": [], }) +def if_not_windows_cuda(a): + return select({ + clean_dep("//tensorflow:with_cuda_support_windows_override"): [], + "//conditions:default": a, + }) + def if_linux_x86_64(a): return select({ clean_dep("//tensorflow:linux_x86_64"): a, @@ -241,6 +247,9 @@ def tf_opts_nortti_if_android(): # LINT.ThenChange(//tensorflow/contrib/android/cmake/CMakeLists.txt) +def tf_features_nomodules_if_android(): + return if_android(["-use_header_modules"]) + # Given a list of "op_lib_names" (a list of files in the ops directory # without their .cc extensions), generate a library for that file. def tf_gen_op_libs(op_lib_names, deps=None, is_external=True): @@ -919,6 +928,7 @@ def tf_gpu_kernel_library(srcs, hdrs=[], **kwargs): copts = copts + _cuda_copts() + if_cuda(cuda_copts) + tf_copts() + kwargs["features"] = kwargs.get("features", []) + ["-use_header_modules"] native.cc_library( srcs=srcs, @@ -959,6 +969,7 @@ def tf_cuda_library(deps=None, cuda_deps=None, copts=tf_copts(), **kwargs): if not cuda_deps: cuda_deps = [] + kwargs["features"] = kwargs.get("features", []) + ["-use_header_modules"] native.cc_library( deps=deps + if_cuda(cuda_deps + [ clean_dep("//tensorflow/core:cuda"), @@ -1301,6 +1312,7 @@ def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[], linkopts=[]): name=basename + "_gpu", srcs=gpu_srcs, copts=_cuda_copts() + if_tensorrt(["-DGOOGLE_TENSORRT=1"]), + features = if_cuda(["-use_header_modules"]), deps=deps + if_cuda(cuda_deps)) cuda_deps.extend([":" + basename + "_gpu"]) diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD index 6065c12cadb86f5f55a4c799bd163d7c561cd1b7..8c760e6f52598a5e7399c9250adf99283572d3a4 100644 --- a/tensorflow/tools/api/generator/BUILD +++ b/tensorflow/tools/api/generator/BUILD @@ -3,38 +3,37 @@ licenses(["notice"]) # Apache 2.0 -exports_files(["LICENSE"]) - load("//tensorflow/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES") load("//tensorflow/tools/api/generator:api_gen.bzl", "TENSORFLOW_API_INIT_FILES") -py_library( - name = "doc_srcs", - srcs = ["doc_srcs.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:util", +exports_files( + [ + "LICENSE", + "create_python_api.py", ], ) -py_binary( - name = "create_python_api", - srcs = ["create_python_api.py"], +py_library( + name = "doc_srcs", + srcs = ["doc_srcs.py"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - ":doc_srcs", - "//tensorflow/python:no_contrib", + "//tensorflow/python:util", ], ) py_test( name = "create_python_api_test", - srcs = ["create_python_api_test.py"], + srcs = [ + "create_python_api.py", + "create_python_api_test.py", + ], srcs_version = "PY2AND3", deps = [ - ":create_python_api", + ":doc_srcs", "//tensorflow/python:client_testlib", + "//tensorflow/python:no_contrib", ], ) @@ -67,5 +66,6 @@ py_test( ":doc_srcs", "//tensorflow/python:client_testlib", "//tensorflow/python:no_contrib", + "//tensorflow/python/estimator:estimator_py", ], ) diff --git a/tensorflow/tools/api/generator/api_gen.bzl b/tensorflow/tools/api/generator/api_gen.bzl index 41713a94ecbde567340a9e1571a06aca3bbda97a..d746b5d3e4f7745d78563eac65ccdf822511a7ef 100644 --- a/tensorflow/tools/api/generator/api_gen.bzl +++ b/tensorflow/tools/api/generator/api_gen.bzl @@ -8,13 +8,16 @@ TENSORFLOW_API_INIT_FILES = [ "bitwise/__init__.py", "compat/__init__.py", "data/__init__.py", + "debugging/__init__.py", "distributions/__init__.py", "distributions/bijectors/__init__.py", + "dtypes/__init__.py", "errors/__init__.py", "feature_column/__init__.py", "gfile/__init__.py", "graph_util/__init__.py", "image/__init__.py", + "io/__init__.py", "initializers/__init__.py", "keras/__init__.py", "keras/activations/__init__.py", @@ -65,6 +68,7 @@ TENSORFLOW_API_INIT_FILES = [ "nn/rnn_cell/__init__.py", "profiler/__init__.py", "python_io/__init__.py", + "quantization/__init__.py", "resource_loader/__init__.py", "strings/__init__.py", "saved_model/__init__.py", @@ -114,22 +118,44 @@ ESTIMATOR_API_INIT_FILES = [ # template will be replaced with root imports collected by this genrule. # srcs: genrule sources. If passing root_init_template, the template file # must be included in sources. -def gen_api_init_files(name, - output_files=TENSORFLOW_API_INIT_FILES, - root_init_template=None, - srcs=[], - api_name="tensorflow", - package="tensorflow.python"): - root_init_template_flag = "" - if root_init_template: - root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")" - native.genrule( - name = name, - outs = output_files, - cmd = ( - "$(location //tensorflow/tools/api/generator:create_python_api) " + - root_init_template_flag + " --apidir=$(@D) --apiname=" + api_name + " --package=" + package + " $(OUTS)"), - srcs = srcs, - tools = ["//tensorflow/tools/api/generator:create_python_api"], - visibility = ["//tensorflow:__pkg__"], - ) +# api_name: Name of the project that you want to generate API files for +# (e.g. "tensorflow" or "estimator"). +# package: Python package containing the @tf_export decorators you want to +# process +# package_dep: Python library target containing your package. + +def gen_api_init_files( + name, + output_files = TENSORFLOW_API_INIT_FILES, + root_init_template = None, + srcs = [], + api_name = "tensorflow", + package = "tensorflow.python", + package_dep = "//tensorflow/python:no_contrib"): + root_init_template_flag = "" + if root_init_template: + root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")" + + api_gen_binary_target = "create_" + package + "_api" + native.py_binary( + name = "create_" + package + "_api", + srcs = ["//tensorflow/tools/api/generator:create_python_api.py"], + main = "//tensorflow/tools/api/generator:create_python_api.py", + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + package_dep, + "//tensorflow/tools/api/generator:doc_srcs", + ], + ) + + native.genrule( + name = name, + outs = output_files, + cmd = ( + "$(location :" + api_gen_binary_target + ") " + + root_init_template_flag + " --apidir=$(@D) --apiname=" + api_name + " --package=" + package + " $(OUTS)"), + srcs = srcs, + tools = [":" + api_gen_binary_target ], + visibility = ["//tensorflow:__pkg__"], + ) diff --git a/tensorflow/tools/api/generator/create_python_api.py b/tensorflow/tools/api/generator/create_python_api.py index 671b7e387e4bfc56aa8ec4d541d0170e0c96baa9..48d7dcd09eb38f53031afde70fe2e1a9b660ad1a 100644 --- a/tensorflow/tools/api/generator/create_python_api.py +++ b/tensorflow/tools/api/generator/create_python_api.py @@ -180,7 +180,7 @@ def get_api_init_text(package, api_name): for module in list(sys.modules.values()): # Only look at tensorflow modules. if (not module or not hasattr(module, '__name__') or - package not in module.__name__): + module.__name__ is None or package not in module.__name__): continue # Do not generate __init__.py files for contrib modules for now. if '.contrib.' in module.__name__ or module.__name__.endswith('.contrib'): diff --git a/tensorflow/tools/api/generator/doc_srcs.py b/tensorflow/tools/api/generator/doc_srcs.py index ccd5bea481b000d5d769b37f0fd848bb73daf056..ad1988494dae4a9d3ee96af5af76f02c52c0dff4 100644 --- a/tensorflow/tools/api/generator/doc_srcs.py +++ b/tensorflow/tools/api/generator/doc_srcs.py @@ -43,7 +43,7 @@ _TENSORFLOW_DOC_SOURCES = { 'gfile': DocSource(docstring_module_name='platform.gfile'), 'graph_util': DocSource(docstring_module_name='framework.graph_util'), 'image': DocSource(docstring_module_name='ops.image_ops'), - 'keras.estimator': DocSource(docstring_module_name='estimator.keras'), + 'keras.estimator': DocSource(docstring_module_name='keras.estimator'), 'linalg': DocSource(docstring_module_name='ops.linalg_ops'), 'logging': DocSource(docstring_module_name='ops.logging_ops'), 'losses': DocSource(docstring_module_name='ops.losses.losses'), diff --git a/tensorflow/tools/api/generator/doc_srcs_test.py b/tensorflow/tools/api/generator/doc_srcs_test.py index 7b8f27c1b1cd474462d0eab2ddc8451bb256496b..dbff904abe6251ad180140c4c7c404f051b17d55 100644 --- a/tensorflow/tools/api/generator/doc_srcs_test.py +++ b/tensorflow/tools/api/generator/doc_srcs_test.py @@ -39,27 +39,27 @@ class DocSrcsTest(test.TestCase): file_path += '/' file_path += '__init__.py' - if file_path not in FLAGS.outputs: - self.assertFalse('%s is not a valid API module' % module_name) + self.assertIn( + file_path, FLAGS.outputs, + msg='%s is not a valid API module' % module_name) def testHaveDocstringOrDocstringModule(self): for module_name, docsrc in doc_srcs.get_doc_sources(FLAGS.api_name).items(): - if docsrc.docstring and docsrc.docstring_module_name: - self.assertFalse( - '%s contains DocSource has both a docstring and a ' - 'docstring_module_name. ' - 'Only one of "docstring" or "docstring_module_name" should be set.' - % (module_name)) + self.assertFalse( + docsrc.docstring and docsrc.docstring_module_name, + msg=('%s contains DocSource has both a docstring and a ' + 'docstring_module_name. Only one of "docstring" or ' + '"docstring_module_name" should be set.') % (module_name)) def testDocstringModulesAreValidModules(self): for _, docsrc in doc_srcs.get_doc_sources(FLAGS.api_name).items(): if docsrc.docstring_module_name: doc_module_name = '.'.join([ FLAGS.package, docsrc.docstring_module_name]) - if doc_module_name not in sys.modules: - self.assertFalse( - 'docsources_module %s is not a valid module under %s.' % - (docsrc.docstring_module_name, FLAGS.package)) + self.assertIn( + doc_module_name, sys.modules, + msg=('docsources_module %s is not a valid module under %s.' % + (docsrc.docstring_module_name, FLAGS.package))) if __name__ == '__main__': diff --git a/tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt b/tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt index f819b174c0b701153af4709fade9313efa7f7fb6..353e63127de174a79c209a05327da2de20bf0dd7 100644 --- a/tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt @@ -72,6 +72,12 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_BOOL } + field { + name: "num_dev_to_dev_copy_streams" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } nested_type { name: "VirtualDevices" field { diff --git a/tensorflow/tools/api/golden/tensorflow.debugging.pbtxt b/tensorflow/tools/api/golden/tensorflow.debugging.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d9efe97821904f5891148b72a0c31e02c9562bd7 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.debugging.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.debugging" +tf_module { + member_method { + name: "check_numerics" + argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_finite" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_inf" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_nan" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.dtypes.pbtxt b/tensorflow/tools/api/golden/tensorflow.dtypes.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..98e1feed002ceb4f455aa5ec361d26a159fdad1a --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.dtypes.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.dtypes" +tf_module { + member_method { + name: "as_string" + argspec: "args=[\'input\', \'precision\', \'scientific\', \'shortest\', \'width\', \'fill\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'False\', \'False\', \'-1\', \'\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt index e268fa3f618adda5323692823070c679726117be..e89b4dbffdfe85f471fb1dd1b976cc701d526c64 100644 --- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt @@ -84,6 +84,10 @@ tf_module { name: "extract_glimpse" argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'None\'], " } + member_method { + name: "extract_image_patches" + argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "extract_jpeg_shape" argspec: "args=[\'contents\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.initializers.variance_scaling.pbtxt b/tensorflow/tools/api/golden/tensorflow.initializers.variance_scaling.pbtxt index a6b6e5eceb62654c9ad567a361f7558a2865e57a..86340913e2506c96499aae05a3ed0d5273c93bba 100644 --- a/tensorflow/tools/api/golden/tensorflow.initializers.variance_scaling.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.initializers.variance_scaling.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'normal\', \'None\', \"\"], " + argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'truncated_normal\', \'None\', \"\"], " } member_method { name: "from_config" diff --git a/tensorflow/tools/api/golden/tensorflow.io.pbtxt b/tensorflow/tools/api/golden/tensorflow.io.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..3a36c168aa703721421b662185fc852fa3d6a3ec --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.io.pbtxt @@ -0,0 +1,39 @@ +path: "tensorflow.io" +tf_module { + member_method { + name: "decode_base64" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "decode_compressed" + argspec: "args=[\'bytes\', \'compression_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], " + } + member_method { + name: "decode_json_example" + argspec: "args=[\'json_examples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "decode_raw" + argspec: "args=[\'bytes\', \'out_type\', \'little_endian\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "encode_base64" + argspec: "args=[\'input\', \'pad\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "matching_files" + argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "parse_tensor" + argspec: "args=[\'serialized\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_file" + argspec: "args=[\'filename\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "write_file" + argspec: "args=[\'filename\', \'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-early-stopping.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-early-stopping.pbtxt index 7b0ad85eaac5b83835a9e1c4b152e38e7051a2f6..f71292856cd29b2e52194bec8a586686fbfad667 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-early-stopping.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-early-stopping.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'monitor\', \'min_delta\', \'patience\', \'verbose\', \'mode\'], varargs=None, keywords=None, defaults=[\'val_loss\', \'0\', \'0\', \'0\', \'auto\'], " + argspec: "args=[\'self\', \'monitor\', \'min_delta\', \'patience\', \'verbose\', \'mode\', \'baseline\'], varargs=None, keywords=None, defaults=[\'val_loss\', \'0\', \'0\', \'0\', \'auto\', \'None\'], " } member_method { name: "on_batch_begin" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-variance-scaling.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-variance-scaling.pbtxt index 32a6f6ee88815b3dc70e9cca855f73099554953b..03f4064b9ef5093044a9cbb897043d643cf7f83e 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-variance-scaling.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-variance-scaling.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'normal\', \'None\', \"\"], " + argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'truncated_normal\', \'None\', \"\"], " } member_method { name: "from_config" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-re-l-u.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f3a96ab895dc9dbf8e2362dbcbfdccdf6af749ec --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-re-l-u.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.ReLU" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'max_value\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt index 475e9dade3a6f5ba8a5020afbd4668be5b5ed9d7..9d7e5bb8c7808689bedd8abb835e61c1f38fdb1d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt @@ -300,6 +300,10 @@ tf_module { name: "RNN" mtype: "" } + member { + name: "ReLU" + mtype: "" + } member { name: "RepeatVector" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt index 00b9238543367546cff96b736f73440214e99e22..3b5845f99a474ed976b91dab4f80ac2f231e7fc1 100644 --- a/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt @@ -68,6 +68,10 @@ tf_module { name: "cholesky_solve" argspec: "args=[\'chol\', \'rhs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "cross" + argspec: "args=[\'a\', \'b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "det" argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -140,6 +144,14 @@ tf_module { name: "svd" argspec: "args=[\'tensor\', \'full_matrices\', \'compute_uv\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], " } + member_method { + name: "tensor_diag" + argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tensor_diag_part" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "tensordot" argspec: "args=[\'a\', \'b\', \'axes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.manip.pbtxt b/tensorflow/tools/api/golden/tensorflow.manip.pbtxt index 0b84165285102daf0a8e3dd6542bfc391e50f77b..9add462396ea526ae94678e969c9acf5bce86df1 100644 --- a/tensorflow/tools/api/golden/tensorflow.manip.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.manip.pbtxt @@ -1,7 +1,35 @@ path: "tensorflow.manip" tf_module { + member_method { + name: "batch_to_space_nd" + argspec: "args=[\'input\', \'block_shape\', \'crops\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "gather_nd" + argspec: "args=[\'params\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reshape" + argspec: "args=[\'tensor\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reverse" + argspec: "args=[\'tensor\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "roll" argspec: "args=[\'input\', \'shift\', \'axis\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "scatter_nd" + argspec: "args=[\'indices\', \'updates\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "space_to_batch_nd" + argspec: "args=[\'input\', \'block_shape\', \'paddings\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tile" + argspec: "args=[\'input\', \'multiples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } } diff --git a/tensorflow/tools/api/golden/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/tensorflow.math.pbtxt index 03fbf6266d24f6a31ef76e2668c15987e55d2651..a308c76ebc08df06c0c360579451ea70e60695d4 100644 --- a/tensorflow/tools/api/golden/tensorflow.math.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.math.pbtxt @@ -1,8 +1,40 @@ path: "tensorflow.math" tf_module { + member_method { + name: "acos" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "acosh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "asin" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "asinh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "atan" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "atan2" + argspec: "args=[\'y\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "atanh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "bessel_i0" - argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'bessel_i0\'], " + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "bessel_i0e" @@ -10,14 +42,198 @@ tf_module { } member_method { name: "bessel_i1" - argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'bessel_i1\'], " + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "bessel_i1e" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "betainc" + argspec: "args=[\'a\', \'b\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "ceil" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "cos" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "cosh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "digamma" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "erfc" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "exp" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "expm1" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "floor" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "greater" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "greater_equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "igamma" + argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "igammac" + argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "invert_permutation" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "less" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "less_equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "lgamma" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "log" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "log1p" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "logical_and" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "logical_not" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "logical_or" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "maximum" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "minimum" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "not_equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "polygamma" + argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "polyval" argspec: "args=[\'coeffs\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "reciprocal" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "rint" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "rsqrt" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_max" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_mean" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_min" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_prod" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_sum" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sin" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sinh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "softplus" + argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "softsign" + argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "squared_difference" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tan" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_max" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_min" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_prod" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_sum" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "zeta" + argspec: "args=[\'x\', \'q\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } } diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index 01b80581188e65d228aaa669254d9951546ecfa0..adab5399b21d1133a6cf6c45cb963834ff49f417 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -308,6 +308,10 @@ tf_module { name: "data" mtype: "" } + member { + name: "debugging" + mtype: "" + } member { name: "distributions" mtype: "" @@ -316,6 +320,10 @@ tf_module { name: "double" mtype: "" } + member { + name: "dtypes" + mtype: "" + } member { name: "errors" mtype: "" @@ -380,6 +388,10 @@ tf_module { name: "int8" mtype: "" } + member { + name: "io" + mtype: "" + } member { name: "keras" mtype: "" @@ -456,6 +468,10 @@ tf_module { name: "qint8" mtype: "" } + member { + name: "quantization" + mtype: "" + } member { name: "quint16" mtype: "" @@ -1294,7 +1310,7 @@ tf_module { } member_method { name: "lbeta" - argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'lbeta\'], " + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "less" diff --git a/tensorflow/tools/api/golden/tensorflow.quantization.pbtxt b/tensorflow/tools/api/golden/tensorflow.quantization.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6d865efed0bfdada8dde64e86ddb5d2b2b364c79 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.quantization.pbtxt @@ -0,0 +1,35 @@ +path: "tensorflow.quantization" +tf_module { + member_method { + name: "dequantize" + argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_args" + argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'-6\', \'6\', \'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_args_gradient" + argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'-6\', \'6\', \'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars" + argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars_gradient" + argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars_per_channel" + argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars_per_channel_gradient" + argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " + } + member_method { + name: "quantized_concat" + argspec: "args=[\'concat_dim\', \'values\', \'input_mins\', \'input_maxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt b/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt index 4f306540ccfdeac8ce59a394ec77b24284f13ceb..6a421ef12d58dc047905ec916cbe777b4ce19b9a 100644 --- a/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt @@ -16,6 +16,10 @@ tf_module { name: "fft3d" argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "idct" + argspec: "args=[\'input\', \'type\', \'n\', \'axis\', \'norm\', \'name\'], varargs=None, keywords=None, defaults=[\'2\', \'None\', \'-1\', \'None\', \'None\'], " + } member_method { name: "ifft" argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/tensorflow.strings.pbtxt index b641c39feb6bcc4b5b73ba81ce0f0d4a499007ea..9a831fed2692b30db6ce991c86f46a42908c0789 100644 --- a/tensorflow/tools/api/golden/tensorflow.strings.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.strings.pbtxt @@ -1,11 +1,43 @@ path: "tensorflow.strings" tf_module { + member_method { + name: "join" + argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], " + } member_method { name: "regex_full_match" argspec: "args=[\'input\', \'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "regex_replace" + argspec: "args=[\'input\', \'pattern\', \'rewrite\', \'replace_global\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } member_method { name: "split" argspec: "args=[\'source\', \'sep\', \'maxsplit\'], varargs=None, keywords=None, defaults=[\'None\', \'-1\'], " } + member_method { + name: "strip" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "substr" + argspec: "args=[\'input\', \'pos\', \'len\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "to_hash_bucket" + argspec: "args=[\'string_tensor\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "to_hash_bucket_fast" + argspec: "args=[\'input\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "to_hash_bucket_strong" + argspec: "args=[\'input\', \'num_buckets\', \'key\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "to_number" + argspec: "args=[\'string_tensor\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + } } diff --git a/tensorflow/tools/api/golden/tensorflow.train.-checkpoint.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-checkpoint.pbtxt index ddc553d7c984b24fe33c03bb90e00e7e81f55d26..2d067e4eff13208cb03ca01b7b8a8018a1e99097 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-checkpoint.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-checkpoint.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.train.Checkpoint" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.variance_scaling_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.variance_scaling_initializer.pbtxt index a58398d645e8397dc8e61a6e0241710c3e34218f..09d7bc03b4f238923db6778ec32ce78ae76eed61 100644 --- a/tensorflow/tools/api/golden/tensorflow.variance_scaling_initializer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.variance_scaling_initializer.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'normal\', \'None\', \"\"], " + argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'truncated_normal\', \'None\', \"\"], " } member_method { name: "from_config" diff --git a/tensorflow/tools/ci_build/Dockerfile.cpu.ppc64le b/tensorflow/tools/ci_build/Dockerfile.cpu.ppc64le new file mode 100644 index 0000000000000000000000000000000000000000..e879c34bbdadd7b90973fda0f7c3fdb71a385856 --- /dev/null +++ b/tensorflow/tools/ci_build/Dockerfile.cpu.ppc64le @@ -0,0 +1,20 @@ +FROM ubuntu:16.04 + +LABEL maintainer="William Irons " + +# Copy and run the install scripts. +COPY install/*.sh /install/ +RUN /install/install_bootstrap_deb_packages.sh +RUN add-apt-repository -y ppa:openjdk-r/ppa +RUN /install/install_deb_packages.sh +RUN apt-get update && apt-get install -y libopenblas-dev +RUN /install/install_hdf5_ppc64le.sh +RUN /install/install_pip_packages.sh +RUN /install/install_bazel_from_source.sh +RUN /install/install_proto3.sh +RUN /install/install_buildifier_from_source.sh +RUN /install/install_auditwheel.sh +RUN /install/install_golang_ppc64le.sh + +# Set up the master bazelrc configuration file. +COPY install/.bazelrc /etc/bazel.bazelrc diff --git a/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le b/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le new file mode 100644 index 0000000000000000000000000000000000000000..89671387472a15c112a09fa2fa7a9798446d135b --- /dev/null +++ b/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le @@ -0,0 +1,28 @@ +FROM nvidia/cuda-ppc64le:9.0-cudnn7-devel-ubuntu16.04 + +LABEL maintainer="William Irons " + +# In the Ubuntu 16.04 images, cudnn is placed in system paths. Move them to +# /usr/local/cuda +RUN cp -P /usr/include/cudnn.h /usr/local/cuda/include +RUN cp -P /usr/lib/powerpc64le-linux-gnu/libcudnn* /usr/local/cuda/lib64 + +# Copy and run the install scripts. +COPY install/*.sh /install/ +ARG DEBIAN_FRONTEND=noninteractive +RUN /install/install_bootstrap_deb_packages.sh +RUN add-apt-repository -y ppa:openjdk-r/ppa +RUN /install/install_deb_packages.sh +RUN apt-get update && apt-get install -y libopenblas-dev +RUN /install/install_hdf5_ppc64le.sh +RUN /install/install_pip_packages.sh +RUN /install/install_bazel_from_source.sh +RUN /install/install_golang_ppc64le.sh + +# Set up the master bazelrc configuration file. +COPY install/.bazelrc /etc/bazel.bazelrc +ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH + +# Configure the build for our CUDA configuration. +ENV TF_NEED_CUDA 1 +ENV TF_CUDA_COMPUTE_CAPABILITIES 3.0 diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cpu b/tensorflow/tools/ci_build/Dockerfile.rbe.cpu index 3bc52b9ed611a0f0a4a269a2864d5b349ee9232c..7e5860aeec186d908e5d2884bd690b2e5e43cffa 100644 --- a/tensorflow/tools/ci_build/Dockerfile.rbe.cpu +++ b/tensorflow/tools/ci_build/Dockerfile.rbe.cpu @@ -1,4 +1,4 @@ -FROM launcher.gcr.io/google/rbe-debian8:r327695 +FROM launcher.gcr.io/google/rbe-ubuntu16-04:r327695 LABEL maintainer="Yu Yi " # Copy install scripts @@ -9,6 +9,6 @@ ENV CC /usr/local/bin/clang ENV CXX /usr/local/bin/clang++ ENV AR /usr/bin/ar -# Run pip install script for RBE Debian8 container. +# Run pip install script for RBE Ubuntu 16-04 container. RUN /install/install_pip_packages_remote.sh RUN /install/install_pip_packages.sh diff --git a/tensorflow/tools/ci_build/ci_build.sh b/tensorflow/tools/ci_build/ci_build.sh index 1f0fd0387af28bf15e5c42fa14f5c1a1ee5a8cfb..f6a50d3d4c4f948e37ff841a880b373f1034fd76 100755 --- a/tensorflow/tools/ci_build/ci_build.sh +++ b/tensorflow/tools/ci_build/ci_build.sh @@ -79,7 +79,7 @@ if [[ "${CONTAINER_TYPE}" == "cmake" ]]; then fi # Use nvidia-docker if the container is GPU. -if [[ "${CONTAINER_TYPE}" == "gpu" ]]; then +if [[ "${CONTAINER_TYPE}" == gpu* ]]; then DOCKER_BINARY="nvidia-docker" else DOCKER_BINARY="docker" @@ -99,7 +99,7 @@ BUILD_TAG="${BUILD_TAG:-tf_ci}" # Add extra params for cuda devices and libraries for GPU container. # And clear them if we are not building for GPU. -if [[ "${CONTAINER_TYPE}" != "gpu" ]]; then +if [[ "${CONTAINER_TYPE}" != gpu* ]]; then GPU_EXTRA_PARAMS="" fi diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh index 90bd8bc3d0c349b5df51535cadc3b1d85d76b7d0..d49d4b0c49cf9ed487249a800e3807140a9a03bf 100755 --- a/tensorflow/tools/ci_build/ci_parameterized_build.sh +++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh @@ -59,6 +59,9 @@ # TF_BUILD_BAZEL_CLEAN: # Will perform "bazel clean", if and only if this variable # is set to any non-empty and non-0 value +# TF_BAZEL_BUILD_ONLY: +# If it is set to any non-empty value that is not "0", Bazel +# will only build specified targets # TF_GPU_COUNT: # Run this many parallel tests for serial builds. # For now, only can be edited for PIP builds. @@ -258,9 +261,9 @@ function set_script_variable() { # Process container type -if [[ ${CTYPE} == "cpu" ]] || [[ ${CTYPE} == "debian.jessie.cpu" ]]; then +if [[ ${CTYPE} == cpu* ]] || [[ ${CTYPE} == "debian.jessie.cpu" ]]; then : -elif [[ ${CTYPE} == "gpu" ]]; then +elif [[ ${CTYPE} == gpu* ]]; then set_script_variable TF_NEED_CUDA 1 if [[ $TF_CUDA_CLANG == "1" ]]; then @@ -410,6 +413,11 @@ fi # this flag, and it only affects a few tests. EXTRA_ARGS="${EXTRA_ARGS} --distinct_host_configuration=false" +if [[ ! -z "${TF_BAZEL_BUILD_ONLY}" ]] && + [[ "${TF_BAZEL_BUILD_ONLY}" != "0" ]];then + BAZEL_CMD=${BAZEL_BUILD_ONLY_CMD} +fi + # Process PIP install-test option if [[ ${TF_BUILD_IS_PIP} == "no_pip" ]] || [[ ${TF_BUILD_IS_PIP} == "both" ]]; then @@ -418,12 +426,12 @@ if [[ ${TF_BUILD_IS_PIP} == "no_pip" ]] || BAZEL_TARGET=${TF_BUILD_BAZEL_TARGET} fi - if [[ ${CTYPE} == "cpu" ]] || \ + if [[ ${CTYPE} == cpu* ]] || \ [[ ${CTYPE} == "debian.jessie.cpu" ]]; then # CPU only command, fully parallel. NO_PIP_MAIN_CMD="${MAIN_CMD} ${BAZEL_CMD} ${OPT_FLAG} ${EXTRA_ARGS} -- "\ "${BAZEL_TARGET}" - elif [[ ${CTYPE} == "gpu" ]]; then + elif [[ ${CTYPE} == gpu* ]]; then # GPU only command, run as many jobs as the GPU count only. NO_PIP_MAIN_CMD="${BAZEL_CMD} ${OPT_FLAG} "\ "--local_test_jobs=${TF_GPU_COUNT} "\ diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index 05676f9551d4a1e0cb55d0693f99e458381887df..f0a437c1831378ebf11eb7869163d2512ac237f9 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -349,12 +349,12 @@ do_external_licenses_check(){ # Blacklist echo ${MISSING_LICENSES_FILE} - grep -e "@bazel_tools//third_party/" -e "@com_google_absl//absl" -e "@org_tensorflow//" -v ${MISSING_LICENSES_FILE} > temp.txt + grep -e "@bazel_tools//third_party/" -e "@com_google_absl//absl" -e "@org_tensorflow//" -e "@com_github_googlecloudplatform_google_cloud_cpp//google" -v ${MISSING_LICENSES_FILE} > temp.txt mv temp.txt ${MISSING_LICENSES_FILE} # Whitelist echo ${EXTRA_LICENSE_FILE} - grep -e "@bazel_tools//src" -e "@bazel_tools//tools/" -e "@com_google_absl//" -e "//external" -e "@local" -v ${EXTRA_LICENSES_FILE} > temp.txt + grep -e "@bazel_tools//src" -e "@bazel_tools//tools/" -e "@com_google_absl//" -e "//external" -e "@local" -e "@com_github_googlecloudplatform_google_cloud_cpp//" -v ${EXTRA_LICENSES_FILE} > temp.txt mv temp.txt ${EXTRA_LICENSES_FILE} diff --git a/tensorflow/tools/ci_build/install/install_bazel_from_source.sh b/tensorflow/tools/ci_build/install/install_bazel_from_source.sh new file mode 100755 index 0000000000000000000000000000000000000000..ddad00c5f01a78164903702b03c816c427aeb0b8 --- /dev/null +++ b/tensorflow/tools/ci_build/install/install_bazel_from_source.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# This script is to be used to install bzel on non x86_64 systems +# It will compile bazel from source and install it in /usr/local/bin + +# Select bazel version. +BAZEL_VERSION="0.11.0" + +set +e +local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}') + +if [[ "$local_bazel_ver" == "$BAZEL_VERSION" ]]; then + exit 0 +fi + +set -e + +# Compile bazel from source +mkdir -p /bazel +cd /bazel + +curl -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-dist.zip +unzip bazel-$BAZEL_VERSION-dist.zip +bash ./compile.sh +cp output/bazel /usr/local/bin/ +rm -rf /bazel diff --git a/tensorflow/tools/ci_build/install/install_buildifier_from_source.sh b/tensorflow/tools/ci_build/install/install_buildifier_from_source.sh new file mode 100755 index 0000000000000000000000000000000000000000..a93c258fad1ca62b0c95f22560110ba231aa0053 --- /dev/null +++ b/tensorflow/tools/ci_build/install/install_buildifier_from_source.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -e +BUILDTOOLS_VERSION="0.11.1" + +# Clone buildtools +git clone -b $BUILDTOOLS_VERSION https://github.com/bazelbuild/buildtools +cd buildtools + +# Build buildifier +bazel build //buildifier +sudo mv bazel-bin/buildifier/linux*stripped/buildifier /usr/local/bin + +# Build buildozer +bazel build //buildozer +sudo mv bazel-bin/buildozer/linux*stripped/buildozer /usr/local/bin diff --git a/tensorflow/tools/ci_build/install/install_golang_ppc64le.sh b/tensorflow/tools/ci_build/install/install_golang_ppc64le.sh new file mode 100755 index 0000000000000000000000000000000000000000..47d23a59b3ee9152ef9812fbe939e20ee7c2b40a --- /dev/null +++ b/tensorflow/tools/ci_build/install/install_golang_ppc64le.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -ex + +GOLANG_URL="https://storage.googleapis.com/golang/go1.10.linux-ppc64le.tar.gz" + +sudo mkdir -p /usr/local +wget -q -O - "${GOLANG_URL}" | sudo tar -C /usr/local -xz diff --git a/tensorflow/tools/ci_build/install/install_hdf5_ppc64le.sh b/tensorflow/tools/ci_build/install/install_hdf5_ppc64le.sh new file mode 100755 index 0000000000000000000000000000000000000000..4989d986b8eb0690f63ecff41f7107371724bc3a --- /dev/null +++ b/tensorflow/tools/ci_build/install/install_hdf5_ppc64le.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +#This is required because pypi doesn't have a pre-built h5py binary for ppc64le +#It has to be compiled from source during the install +apt-get update +apt-get install -y libhdf5-dev + +#h5py is not expecting the shared libraries to have _serial in the name. +ln -s /usr/lib/powerpc64le-linux-gnu/libhdf5_serial.so /usr/lib/powerpc64le-linux-gnu/libhdf5.so +ln -s /usr/lib/powerpc64le-linux-gnu/libhdf5_serial_hl.so /usr/lib/powerpc64le-linux-gnu/libhdf5_hl.so + +#pip is not installed yet, so use easy_install +#CPATH is the location of hdf5.h +CPATH=/usr/include/hdf5/serial/ easy_install -U h5py +CPATH=/usr/include/hdf5/serial/ easy_install3 -U h5py diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh index 386e66cc212a934ace7814ab3961863e741b6915..221b5b80fb48979af09cb99a5c35cbe5fc4e5ca1 100755 --- a/tensorflow/tools/ci_build/install/install_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh @@ -51,8 +51,8 @@ pip2 install --upgrade markdown==2.6.8 pip3 install --upgrade markdown==2.6.8 # Install protobuf. -pip2 install --upgrade protobuf==3.3.0 -pip3 install --upgrade protobuf==3.3.0 +pip2 install --upgrade protobuf==3.6.0 +pip3 install --upgrade protobuf==3.6.0 # Remove obsolete version of six, which can sometimes confuse virtualenv. rm -rf /usr/lib/python3/dist-packages/six* @@ -119,3 +119,7 @@ pip2 install keras_applications==1.0.2 pip3 install keras_applications==1.0.2 pip2 install keras_preprocessing==1.0.1 pip3 install keras_preprocessing==1.0.1 + +# Install last working version of setuptools. +pip2 install --upgrade setuptools==39.1.0 +pip3 install --upgrade setuptools==39.1.0 diff --git a/tensorflow/tools/ci_build/install/install_proto3.sh b/tensorflow/tools/ci_build/install/install_proto3.sh index 7934002b2c982cd10216016f8614b70b77b58e29..821d50baff325106fceca368d46042401d13c336 100755 --- a/tensorflow/tools/ci_build/install/install_proto3.sh +++ b/tensorflow/tools/ci_build/install/install_proto3.sh @@ -17,7 +17,7 @@ # Install protobuf3. # Select protobuf version. -PROTOBUF_VERSION="3.3.0" +PROTOBUF_VERSION="3.6.0" protobuf_ver_flat=$(echo $PROTOBUF_VERSION | sed 's/\.//g' | sed 's/^0*//g') local_protobuf_ver=$(protoc --version) local_protobuf_ver_flat=$(echo $local_protobuf_ver | sed 's/\.//g' | sed 's/^0*//g') diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh index 4e28fa74b9bf45aced635f6436cc78f4ecdefba5..45a30c6e82c336a0171c7602e09f2184f1459175 100755 --- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh @@ -48,7 +48,7 @@ pip3.5 install --upgrade absl-py pip3.5 install --upgrade six==1.10.0 # Install protobuf. -pip3.5 install --upgrade protobuf==3.3.0 +pip3.5 install --upgrade protobuf==3.6.0 # Remove obsolete version of six, which can sometimes confuse virtualenv. rm -rf /usr/lib/python3/dist-packages/six* @@ -88,4 +88,7 @@ pip3.5 install --upgrade setuptools==39.1.0 pip3.5 install keras_applications==1.0.2 pip3.5 install keras_preprocessing==1.0.1 +# Install last working version of setuptools. +pip3.5 install --upgrade setuptools==39.1.0 + # LINT.ThenChange(//tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh) diff --git a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh index a0b43199a2726bd9b9497daf7c489e9b6dc1f784..d66b2aa18a7d77dd697031cfd2616712d586280a 100755 --- a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh @@ -60,7 +60,7 @@ pip3 install --upgrade absl-py pip3 install --upgrade six==1.10.0 # Install protobuf. -pip3 install --upgrade protobuf==3.3.0 +pip3 install --upgrade protobuf==3.6.0 # Remove obsolete version of six, which can sometimes confuse virtualenv. rm -rf /usr/lib/python3/dist-packages/six* diff --git a/tensorflow/tools/ci_build/linux/mkl/build-dev-container.sh b/tensorflow/tools/ci_build/linux/mkl/build-dev-container.sh new file mode 100755 index 0000000000000000000000000000000000000000..ad22ebe4eb304fe6b6f8613f43f2c7c001111503 --- /dev/null +++ b/tensorflow/tools/ci_build/linux/mkl/build-dev-container.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Build a whl and container with Intel(R) MKL support +# Usage: build-dev-container.sh + +# Helper function to traverse directories up until given file is found. +function upsearch () { + test / == "$PWD" && return || \ + test -e "$1" && echo "$PWD" && return || \ + cd .. && upsearch "$1" +} + +# Set up WORKSPACE. +WORKSPACE="${WORKSPACE:-$(upsearch WORKSPACE)}" + +TF_DOCKER_BUILD_DEVEL_BRANCH=${TF_DOCKER_BUILD_DEVEL_BRANCH:-master} +TF_DOCKER_BUILD_IMAGE_NAME=${TF_DOCKER_BUILD_IMAGE_NAME:-intel-mkl/tensorflow} +TF_DOCKER_BUILD_VERSION=${TF_DOCKER_BUILD_VERSION:-nightly} + +echo "TF_DOCKER_BUILD_DEVEL_BRANCH=${TF_DOCKER_BUILD_DEVEL_BRANCH}" +echo "TF_DOCKER_BUILD_IMAGE_NAME=${TF_DOCKER_BUILD_IMAGE_NAME}" +echo "TF_DOCKER_BUILD_VERSION=${TF_DOCKER_BUILD_VERSION}" + +# build the python 2 container and whl +TF_DOCKER_BUILD_TYPE="MKL" \ + TF_DOCKER_BUILD_IS_DEVEL="YES" \ + TF_DOCKER_BUILD_DEVEL_BRANCH="${TF_DOCKER_BUILD_DEVEL_BRANCH}" \ + TF_DOCKER_BUILD_IMAGE_NAME="${TF_DOCKER_BUILD_IMAGE_NAME}" \ + TF_DOCKER_BUILD_VERSION="${TF_DOCKER_BUILD_VERSION}" \ + ${WORKSPACE}/tensorflow/tools/docker/parameterized_docker_build.sh + +# build the python 3 container and whl +TF_DOCKER_BUILD_TYPE="MKL" \ + TF_DOCKER_BUILD_IS_DEVEL="YES" \ + TF_DOCKER_BUILD_DEVEL_BRANCH="${TF_DOCKER_BUILD_DEVEL_BRANCH}" \ + TF_DOCKER_BUILD_IMAGE_NAME="${TF_DOCKER_BUILD_IMAGE_NAME}" \ + TF_DOCKER_BUILD_VERSION="${TF_DOCKER_BUILD_VERSION}" \ + TF_DOCKER_BUILD_PYTHON_VERSION="PYTHON3" \ + ${WORKSPACE}/tensorflow/tools/docker/parameterized_docker_build.sh + diff --git a/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh b/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh index b8bce57c87ab39ab2f51288163187f2e87c9135d..3d27e84b81c586729aff21d0859383c24f436a11 100755 --- a/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh +++ b/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh @@ -65,6 +65,10 @@ OPENBLAS_SRC_PATH=/tmp/openblas_src/ sudo rm -rf ${OPENBLAS_SRC_PATH} git clone https://github.com/xianyi/OpenBLAS ${OPENBLAS_SRC_PATH} cd ${OPENBLAS_SRC_PATH} +# The commit after this introduced Fortran compile issues. In theory they should +# be solvable using NOFORTRAN=1 on the make command, but my initial tries didn't +# work, so pinning to the last know good version. +git checkout 5a6a2bed9aff0ba8a18651d5514d029c8cae336a # If this path is changed, you'll also need to update # cxx_builtin_include_directory in third_party/toolchains/cpus/arm/CROSSTOOL.tpl OPENBLAS_INSTALL_PATH=/tmp/openblas_install/ diff --git a/tensorflow/tools/ci_build/update_version.py b/tensorflow/tools/ci_build/update_version.py index 00bfcfd49bd1d90dccf094de21173ca9e4307319..642dde36a7caae35df764d5d7513df972e1e5615 100755 --- a/tensorflow/tools/ci_build/update_version.py +++ b/tensorflow/tools/ci_build/update_version.py @@ -37,7 +37,7 @@ SETUP_PY = "%s/tools/pip_package/setup.py" % TF_SRC_DIR README_MD = "./README.md" DEVEL_DOCKERFILE = "%s/tools/docker/Dockerfile.devel" % TF_SRC_DIR GPU_DEVEL_DOCKERFILE = "%s/tools/docker/Dockerfile.devel-gpu" % TF_SRC_DIR -CPU_MKL_DEVEL_DOCKERFILE = "%s/tools/docker/Dockerfile.devel-cpu-mkl" % TF_SRC_DIR +CPU_MKL_DEVEL_DOCKERFILE = "%s/tools/docker/Dockerfile.devel-mkl" % TF_SRC_DIR RELEVANT_FILES = [TF_SRC_DIR, VERSION_H, SETUP_PY, diff --git a/tensorflow/tools/ci_build/windows/bazel/common_env.sh b/tensorflow/tools/ci_build/windows/bazel/common_env.sh index eefa8ee2d504945991c91e1574b6a74330ba3a8d..8a237e4e28376771742ba93b795950d368660196 100644 --- a/tensorflow/tools/ci_build/windows/bazel/common_env.sh +++ b/tensorflow/tools/ci_build/windows/bazel/common_env.sh @@ -49,3 +49,15 @@ export PATH="/c/Program Files/Git/cmd:$PATH" # Make sure we have pip in PATH export PATH="/c/${PYTHON_BASE_PATH}/Scripts:$PATH" + +# Setting default values to CUDA related environment variables +export TF_CUDA_VERSION=${TF_CUDA_VERSION:-9.0} +export TF_CUDNN_VERSION=${TF_CUDNN_VERSION:-7.0} +export TF_CUDA_COMPUTE_CAPABILITIES=${TF_CUDA_COMPUTE_CAPABILITIES:-3.7} +export CUDA_INSTALL_PATH=${CUDA_INSTALL_PATH:-"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${TF_CUDA_VERSION}"} +export CUDNN_INSTALL_PATH=${CUDNN_INSTALL_PATH:-"C:/tools/cuda"} + +# Add Cuda and Cudnn dll directories into PATH +export PATH="$(cygpath -u "${CUDA_INSTALL_PATH}")/bin:$PATH" +export PATH="$(cygpath -u "${CUDA_INSTALL_PATH}")/extras/CUPTI/libx64:$PATH" +export PATH="$(cygpath -u "${CUDNN_INSTALL_PATH}")/bin:$PATH" diff --git a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh index 0b13b97209fa6cd6c629a64fdd54a0423535a9a3..5c305f7512852dd6b3e43c4745e7f24c8a4502aa 100644 --- a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh +++ b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh @@ -77,7 +77,12 @@ fi # to distinct them. This helps avoid building the same targets twice. echo "build --distinct_host_configuration=false" >> "${TMP_BAZELRC}" -echo "import %workspace%/${TMP_BAZELRC}" >> .bazelrc +# Enable short object file path to avoid long path issue on Windows. +echo "startup --output_user_root=${TMPDIR}" >> "${TMP_BAZELRC}" + +if ! grep -q "import %workspace%/${TMP_BAZELRC}" .bazelrc; then + echo "import %workspace%/${TMP_BAZELRC}" >> .bazelrc +fi run_configure_for_cpu_build diff --git a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh index 922bb67bbf6ce34f55acad6d3399bd810032abd0..ededad615aa7e8ef5ef3c050bc36141523db1c71 100644 --- a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh +++ b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh @@ -42,9 +42,58 @@ source "tensorflow/tools/ci_build/windows/bazel/common_env.sh" \ source "tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh" \ || { echo "Failed to source bazel_test_lib.sh" >&2; exit 1; } +# Recreate an empty bazelrc file under source root +export TMP_BAZELRC=.tmp.bazelrc +rm -f "${TMP_BAZELRC}" +touch "${TMP_BAZELRC}" + +function cleanup { + # Remove all options in .tmp.bazelrc + echo "" > "${TMP_BAZELRC}" +} +trap cleanup EXIT + +skip_test=0 +release_build=0 + +for ARG in "$@"; do + if [[ "$ARG" == --skip_test ]]; then + skip_test=1 + elif [[ "$ARG" == --enable_gcs_remote_cache ]]; then + set_gcs_remote_cache_options + elif [[ "$ARG" == --release_build ]]; then + release_build=1 + fi +done + +if [[ "$release_build" != 1 ]]; then + # --define=override_eigen_strong_inline=true speeds up the compiling of conv_grad_ops_3d.cc and conv_ops_3d.cc + # by 20 minutes. See https://github.com/tensorflow/tensorflow/issues/10521 + # Because this hurts the performance of TF, we don't enable it in release build. + echo "build --define=override_eigen_strong_inline=true" >> "${TMP_BAZELRC}" +fi + +# The host and target platforms are the same in Windows build. So we don't have +# to distinct them. This helps avoid building the same targets twice. +echo "build --distinct_host_configuration=false" >> "${TMP_BAZELRC}" + +# Enable short object file path to avoid long path issue on Windows. +echo "startup --output_user_root=${TMPDIR}" >> "${TMP_BAZELRC}" + +# Disable nvcc warnings to reduce log file size. +echo "build --copt=-nvcc_options=disable-warnings" >> "${TMP_BAZELRC}" + +if ! grep -q "import %workspace%/${TMP_BAZELRC}" .bazelrc; then + echo "import %workspace%/${TMP_BAZELRC}" >> .bazelrc +fi + run_configure_for_gpu_build -bazel build -c opt tensorflow/tools/pip_package:build_pip_package || exit $? +bazel build --announce_rc --config=opt tensorflow/tools/pip_package:build_pip_package || exit $? + +if [[ "$skip_test" == 1 ]]; then + exit 0 +fi # Create a python test directory to avoid package name conflict PY_TEST_DIR="py_test_dir" @@ -59,8 +108,11 @@ reinstall_tensorflow_pip ${PIP_NAME} # Define no_tensorflow_py_deps=true so that every py_test has no deps anymore, # which will result testing system installed tensorflow # GPU tests are very flaky when running concurrently, so set local_test_jobs=1 -bazel test -c opt -k --test_output=errors \ +bazel test --announce_rc --config=opt -k --test_output=errors \ --define=no_tensorflow_py_deps=true --test_lang_filters=py \ - --test_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,no_oss \ - --build_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,no_oss \ - --local_test_jobs=1 --build_tests_only //${PY_TEST_DIR}/tensorflow/python/... + --test_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,-no_oss \ + --build_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,-no_oss --build_tests_only \ + --local_test_jobs=1 --test_timeout="300,450,1200,3600" \ + --flaky_test_attempts=3 \ + //${PY_TEST_DIR}/tensorflow/python/... \ + //${PY_TEST_DIR}/tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh index 583d1d5f09527861015458c636af2259b34d45f8..fdbd1120b20ea4461a4ec5f84c666d8b62309905 100755 --- a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh +++ b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh @@ -41,7 +41,7 @@ run_configure_for_cpu_build # build_libtensorflow_tarball in ../builds/libtensorflow.sh # cannot be used on Windows since it relies on pkg_tar rules. # So we do something special here -bazel build -c opt --copt=/arch:AVX \ +bazel --output_user_root=${TMPDIR} build -c opt --copt=/arch:AVX \ tensorflow:libtensorflow.so \ tensorflow/tools/lib_package:clicenses_generate \ tensorflow/java:libtensorflow_jni.so \ diff --git a/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl b/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl deleted file mode 100644 index 6796ad70e5d22ca683343680b142081d8d58a9e4..0000000000000000000000000000000000000000 --- a/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl +++ /dev/null @@ -1,83 +0,0 @@ -FROM tensorflow/tensorflow:latest-devel - -LABEL maintainer="Clayne Robison" - -# These arguments are parameterized. Use --build-args to override. -ARG TF_BRANCH=r1.9 -ARG WHL_DIR=/whl - -RUN apt-get update && apt-get install -y --no-install-recommends \ - golang \ - vim \ - emacs \ - && \ - apt-get clean && \ - rm -rf /var/lib/apt/lists/* - -RUN pip --no-cache-dir install --upgrade \ - pip setuptools - -RUN pip --no-cache-dir install wheel - -# Download and build TensorFlow. -WORKDIR / -RUN rm -rf tensorflow && \ - git clone https://github.com/tensorflow/tensorflow.git && \ - cd tensorflow && \ - git checkout ${TF_BRANCH} -WORKDIR /tensorflow - -# Configure the build for CPU with MKL by accepting default build options and -# setting library locations -ENV CI_BUILD_PYTHON=python \ - LD_LIBRARY_PATH=${LD_LIBRARY_PATH} \ - PYTHON_BIN_PATH=/usr/bin/python \ - PYTHON_LIB_PATH=/usr/local/lib/python2.7/dist-packages \ - CC_OPT_FLAGS='-march=native' \ - TF_NEED_JEMALLOC=0 \ - TF_NEED_GCP=1 \ - TF_NEED_CUDA=0 \ - TF_NEED_HDFS=0 \ - TF_NEED_S3=1 \ - TF_NEED_OPENCL=0 \ - TF_NEED_GDR=0 \ - TF_ENABLE_XLA=0 \ - TF_NEED_VERBS=0 \ - TF_NEED_MPI=0 -RUN ./configure - -# Build and Install TensorFlow. -# The 'mkl' option builds with Intel(R) Math Kernel Library (MKL), which detects -# the platform it is currently running on and takes appropriately optimized -# paths. The -march=native option is for code that is not in MKL, and assumes -# this container will be run on the same architecture on which it is built. -RUN LD_LIBRARY_PATH=${LD_LIBRARY_PATH} \ - bazel build --config=mkl \ - --config="opt" \ - --copt="-march=broadwell" \ - --copt="-O3" \ - //tensorflow/tools/pip_package:build_pip_package && \ - mkdir ${WHL_DIR} && \ - bazel-bin/tensorflow/tools/pip_package/build_pip_package ${WHL_DIR} - -# Clean up Bazel cache when done, but leave the whl. -# This will upgrade the default Tensorflow version with the Intel MKL version -RUN pip --no-cache-dir install --upgrade ${WHL_DIR}/tensorflow-*.whl && \ - rm -rf /root/.cache - -WORKDIR /root - -#add welcome message with instructions - -RUN echo '[ ! -z "$TERM" -a -r /etc/motd ] && cat /etc/issue && cat /etc/motd' \ - >> /etc/bash.bashrc \ - ; echo "\ -||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||\n\ -| \n\ -| Docker container running Ubuntu \n\ -| with TensorFlow ${TF_BRANCH} optimized for CPU \n\ -| with Intel(R) MKL \n\ -| \n\ -||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||\n\ -\n "\ - > /etc/motd diff --git a/tensorflow/tools/docker/Dockerfile.devel-mkl b/tensorflow/tools/docker/Dockerfile.devel-mkl new file mode 100755 index 0000000000000000000000000000000000000000..de44ba21734acbd78c4cad6dc2aca7672c9d574b --- /dev/null +++ b/tensorflow/tools/docker/Dockerfile.devel-mkl @@ -0,0 +1,128 @@ +FROM ubuntu:16.04 + +LABEL maintainer="Clayne Robison " + +# These parameters can be overridden by parameterized_docker_build.sh +ARG TF_BUILD_VERSION=r1.9 +ARG PYTHON="python" +ARG PYTHON3_DEV="" +ARG WHL_DIR="/tmp/pip" +ARG PIP="pip" + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + git \ + libcurl3-dev \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libpng12-dev \ + libzmq3-dev \ + pkg-config \ + python-dev \ + ${PYTHON3_DEV} \ + rsync \ + software-properties-common \ + unzip \ + zip \ + zlib1g-dev \ + openjdk-8-jdk \ + openjdk-8-jre-headless \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +RUN curl -fSsL -O https://bootstrap.pypa.io/get-pip.py && \ + ${PYTHON} get-pip.py && \ + rm get-pip.py + +RUN ${PIP} --no-cache-dir install \ + Pillow \ + h5py \ + ipykernel \ + jupyter \ + matplotlib \ + mock \ + numpy \ + scipy \ + sklearn \ + pandas \ + && \ + ${PYTHON} -m ipykernel.kernelspec + +RUN if [ "${PYTHON}" = "python3" ]; then \ + ln -s -f /usr/bin/python3 /usr/bin/python; \ + fi + +# Set up our notebook config. +COPY jupyter_notebook_config.py /root/.jupyter/ + +# Jupyter has issues with being run directly: +# https://github.com/ipython/ipython/issues/7062 +# We just add a little wrapper script. +COPY run_jupyter.sh / + +# Set up Bazel. + +# Running bazel inside a `docker build` command causes trouble, cf: +# https://github.com/bazelbuild/bazel/issues/134 +# The easiest solution is to set up a bazelrc file forcing --batch. +RUN echo "startup --batch" >>/etc/bazel.bazelrc +# Similarly, we need to workaround sandboxing issues: +# https://github.com/bazelbuild/bazel/issues/418 +RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \ + >>/etc/bazel.bazelrc +# Install the most recent bazel release. +ENV BAZEL_VERSION 0.11.0 +WORKDIR / +RUN mkdir /bazel && \ + cd /bazel && \ + curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \ + curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \ + chmod +x bazel-*.sh && \ + ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \ + cd / && \ + rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh + +# Download and build TensorFlow. +WORKDIR /tensorflow + +# Download and build TensorFlow. +# Enable checking out both tags and branches +RUN export TAG_PREFIX="v" && \ + echo ${TF_BUILD_VERSION} | grep -q ^${TAG_PREFIX}; \ + if [ $? -eq 0 ]; then \ + git clone --depth=1 https://github.com/tensorflow/tensorflow.git . && \ + git fetch --tags && \ + git checkout ${TF_BUILD_VERSION}; \ + else \ + git clone --depth=1 --branch=${TF_BUILD_VERSION} https://github.com/tensorflow/tensorflow.git . ; \ + fi + +RUN yes "" | ${PYTHON} configure.py + +ENV CI_BUILD_PYTHON ${PYTHON} + +# Set bazel build parameters in .bazelrc in parameterized_docker_build.sh +# Use --copt=-march values to get optimized builds appropriate for the hardware +# platform of your choice. +# For ivy-bridge or sandy-bridge +# --copt=-march="avx" \ +# For haswell, broadwell, or skylake +# --copt=-march="avx2" \ +COPY .bazelrc /root/.bazelrc + +RUN tensorflow/tools/ci_build/builds/configured CPU \ + bazel --bazelrc=/root/.bazelrc build -c opt \ + tensorflow/tools/pip_package:build_pip_package && \ + bazel-bin/tensorflow/tools/pip_package/build_pip_package "${WHL_DIR}" && \ + ${PIP} --no-cache-dir install --upgrade "${WHL_DIR}"/tensorflow-*.whl && \ + rm -rf /root/.cache +# Clean up Bazel cache when done. + +# TensorBoard +EXPOSE 6006 +# IPython +EXPOSE 8888 + +WORKDIR /root diff --git a/tensorflow/tools/docker/Dockerfile.mkl b/tensorflow/tools/docker/Dockerfile.mkl new file mode 100755 index 0000000000000000000000000000000000000000..139395d49102fe2de3e241936095613da3f21bf8 --- /dev/null +++ b/tensorflow/tools/docker/Dockerfile.mkl @@ -0,0 +1,75 @@ +FROM ubuntu:16.04 + +LABEL maintainer="Clayne Robison " + +# This parameter MUST be set by parameterized_docker_build.sh +ARG TF_WHL_URL + +# Optional parameters +ARG TF_BUILD_VERSION=r1.9 +ARG PYTHON="python" +ARG PYTHON_DEV="python-dev" +ARG PIP="pip" + +# Pick up some TF dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libpng12-dev \ + libzmq3-dev \ + pkg-config \ + python \ + ${PYTHON_DEV} \ + rsync \ + software-properties-common \ + unzip \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +RUN curl -O https://bootstrap.pypa.io/get-pip.py && \ + python get-pip.py && \ + rm get-pip.py + +RUN ${PIP} --no-cache-dir install \ + Pillow \ + h5py \ + ipykernel \ + jupyter \ + matplotlib \ + numpy \ + pandas \ + scipy \ + sklearn \ + && \ + python -m ipykernel.kernelspec + +COPY ${TF_WHL_URL} / +RUN ${PIP} install --no-cache-dir --force-reinstall /${TF_WHL_URL} && \ + rm -rf /${TF_WHL_URL} + +RUN if [ "${PYTHON}" = "python3" ]; then \ + ln -s -f /usr/bin/python3 /usr/bin/python; \ + fi + +# Set up our notebook config. +COPY jupyter_notebook_config.py /root/.jupyter/ + +# Copy sample notebooks. +COPY notebooks /notebooks + +# Jupyter has issues with being run directly: +# https://github.com/ipython/ipython/issues/7062 +# We just add a little wrapper script. +COPY run_jupyter.sh / + +# TensorBoard +EXPOSE 6006 +# IPython +EXPOSE 8888 + +WORKDIR "/notebooks" + +CMD ["/run_jupyter.sh", "--allow-root"] diff --git a/tensorflow/tools/docker/parameterized_docker_build.sh b/tensorflow/tools/docker/parameterized_docker_build.sh index 05de25f2cb11d76f223a31bc12329e6ab7368e8a..4681c5fd61158e0be998d72bb4329f204808eda7 100755 --- a/tensorflow/tools/docker/parameterized_docker_build.sh +++ b/tensorflow/tools/docker/parameterized_docker_build.sh @@ -19,8 +19,8 @@ # parameterized_docker_build.sh # # The script obeys the following environment variables: -# TF_DOCKER_BUILD_TYPE: (CPU | GPU) -# CPU or GPU image +# TF_DOCKER_BUILD_TYPE: (CPU | GPU | MKL) +# CPU, GPU, or MKL image # # TF_DOCKER_BUILD_IS_DEVEL: (NO | YES) # Is this developer image @@ -87,6 +87,15 @@ # TF_DOCKER_BUILD_OPTIONS # (Optional) # Specifies the desired build options. Defaults to OPT. +# +# TF_DOCKER_BUILD_ARGS +# (Optional) +# A list (array) of docker build args. Will be passed to docker build +# command as list of --build-arg parameters. +# +# TF_BAZEL_BUILD_OPTIONS +# (Optional) +# Bazel compiler flags to be passed to the bazelrc file # Script directory SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" @@ -116,6 +125,8 @@ echo " TF_DOCKER_BUILD_IMAGE_NAME=${TF_DOCKER_BUILD_IMAGE_NAME}" echo " TF_DOCKER_BUILD_VERSION=${TF_DOCKER_BUILD_VERSION}" echo " TF_DOCKER_BUILD_PORT=${TF_DOCKER_BUILD_PORT}" echo " TF_DOCKER_BUILD_PUSH_CMD=${TF_DOCKER_BUILD_PUSH_CMD}" +echo " TF_DOCKER_BUILD_ARGS=${TF_DOCKER_BUILD_ARGS[@]:-()}" +echo " TF_BAZEL_BUILD_OPTIONS=${TF_BAZEL_BUILD_OPTIONS}" CONTAINER_PORT=${TF_DOCKER_BUILD_PORT:-8888} @@ -149,6 +160,15 @@ fi if [[ ${TF_DOCKER_BUILD_TYPE} == "cpu" ]]; then DOCKER_BINARY="docker" +elif [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then + DOCKER_BINARY="docker" + FINAL_TAG="${FINAL_TAG}-mkl" + if [[ ${ORIG_DOCKERFILE} == *"."* ]]; then + # There is already a dot in the tag, use "-" + ORIG_DOCKERFILE="${ORIG_DOCKERFILE}-mkl" + else + ORIG_DOCKERFILE="${ORIG_DOCKERFILE}.mkl" + fi elif [[ ${TF_DOCKER_BUILD_TYPE} == "gpu" ]]; then DOCKER_BINARY="nvidia-docker" @@ -203,6 +223,10 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then export TF_BUILD_OPTIONS=${TF_DOCKER_BUILD_OPTIONS} export TF_BUILD_IS_PIP="PIP" + if [[ "${TF_DOCKER_BUILD_TYPE}" == "mkl" ]]; then + die "FAIL: Non-development MKL builds require a pre-built pip whl." + fi + if [[ "${TF_DOCKER_BUILD_TYPE}" == "gpu" ]]; then export TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS=\ "${TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS} -e TF_CUDA_COMPUTE_CAPABILITIES=3.0,3.5,5.2" @@ -255,25 +279,39 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then # Use string replacement to put the correct file name into the Dockerfile PIP_WHL=$(basename "${PIP_WHL}") - # Modify the non-devel Dockerfile to point to the correct pip whl file - # location - sed -e "/# --- DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/,"\ + if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then + TF_DOCKER_BUILD_ARGS+=("--build-arg TF_WHL_URL=${PIP_WHL}" ) + cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}" + else + # Modify the non-devel Dockerfile to point to the correct pip whl file + # location + sed -e "/# --- DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/,"\ "/# --- ~ DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/c"\ "COPY ${PIP_WHL} /\n"\ "RUN pip --no-cache-dir install /${PIP_WHL}" "${ORIG_DOCKERFILE}" \ - > "${DOCKERFILE}" + > "${DOCKERFILE}" + fi echo "Using local pip wheel from: ${TF_DOCKER_BUILD_CENTRAL_PIP}" echo - else echo "Downloading pip wheel from: ${TF_DOCKER_BUILD_CENTRAL_PIP}" - echo - - # Modify the non-devel Dockerfile to point to the correct pip whl URL. - sed -e "/# --- DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/,"\ + if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then + pushd "${TMP_DIR}/" + curl -O ${TF_DOCKER_BUILD_CENTRAL_PIP} + popd + PIP_WHL_PATH=`find ${TMP_DIR} -name "*.whl"` + PIP_WHL=$(basename "${PIP_WHL_PATH}") + echo "PIP_WHL= ${PIP_WHL}" + echo + TF_DOCKER_BUILD_ARGS+=("--build-arg TF_WHL_URL=${PIP_WHL}") + cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}" + else + # Modify the non-devel Dockerfile to point to the correct pip whl URL. + sed -e "/# --- DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/,"\ "/# --- ~ DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/c"\ "RUN pip --no-cache-dir install ${TF_DOCKER_BUILD_CENTRAL_PIP}" "${ORIG_DOCKERFILE}" \ - > "${DOCKERFILE}" + > "${DOCKERFILE}" + fi fi echo "Modified Dockerfile at: ${DOCKERFILE}" @@ -281,36 +319,66 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then # Modify python/pip version if necessary. if [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python3" ]]; then - if sed -i -e 's/python /python3 /g' "${DOCKERFILE}" && \ - sed -i -e 's/python-dev/python3-dev/g' "${DOCKERFILE}" && \ - sed -i -e 's/pip /pip3 /g' "${DOCKERFILE}" && \ - sed -i -e 's^# RUN ln -s -f /usr/bin/python3 /usr/bin/python#^RUN ln -s -f /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}" - then - echo "Modified Dockerfile for python version "\ -"${TF_DOCKER_BUILD_PYTHON_VERSION} at: ${DOCKERFILE}" + if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then + TF_DOCKER_BUILD_ARGS+=("--build-arg PYTHON=${TF_DOCKER_BUILD_PYTHON_VERSION}") + TF_DOCKER_BUILD_ARGS+=("--build-arg PYTHON_DEV=python3-dev") + TF_DOCKER_BUILD_ARGS+=("--build-arg PIP=pip3") + cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}" else - die "FAILED to modify ${DOCKERFILE} for python3" + if sed -i -e 's/python /python3 /g' "${DOCKERFILE}" && \ + sed -i -e 's/python-dev/python3-dev/g' "${DOCKERFILE}" && \ + sed -i -e 's/pip /pip3 /g' "${DOCKERFILE}" && \ + sed -i -e 's^# RUN ln -s -f /usr/bin/python3 /usr/bin/python#^RUN ln -s -f /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}" + then + echo "Modified Dockerfile for python version "\ + "${TF_DOCKER_BUILD_PYTHON_VERSION} at: ${DOCKERFILE}" + else + die "FAILED to modify ${DOCKERFILE} for python3" + fi fi fi -else +else # TF_DOCKER_BUILD_IS_DEVEL == 'yes' DOCKERFILE="${TMP_DIR}/Dockerfile" - # Modify the devel Dockerfile to specify the git branch - sed "s/^RUN git clone --branch=.* --depth=1/RUN git clone --branch=${TF_DOCKER_BUILD_DEVEL_BRANCH} --depth=1/" \ - "${ORIG_DOCKERFILE}" > "${DOCKERFILE}" + # Set up Dockerfile ARGS for mkl build + if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then + if [[ -z "${TF_BAZEL_BUILD_OPTIONS// }" ]]; then + TF_BAZEL_BUILD_OPTIONS=("--config=mkl --copt=-mavx --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0") + else + TF_BAZEL_BUILD_OPTIONS="${TF_BAZEL_BUILD_OPTIONS}" + fi + TF_DOCKER_BUILD_ARGS+=("--build-arg TF_BUILD_VERSION=${TF_DOCKER_BUILD_DEVEL_BRANCH}") + echo "TF_DOCKER_BUILD_ARGS=${TF_DOCKER_BUILD_ARGS[@]}" + + # Pass the build options to bazel using the user-specific .bazelrc file + echo "build ${TF_BAZEL_BUILD_OPTIONS}" >> ${TMP_DIR}/.bazelrc + cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}" + else + # Modify the devel Dockerfile to specify the git branch + sed "s/^RUN git clone --branch=.* --depth=1/RUN git clone --branch=${TF_DOCKER_BUILD_DEVEL_BRANCH} --depth=1/" \ + "${ORIG_DOCKERFILE}" > "${DOCKERFILE}" + fi # Modify python/pip version if necessary. if [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python3" ]]; then - if sed -i -e 's/python-dev/python-dev python3-dev/g' "${DOCKERFILE}" && \ - sed -i -e 's/python /python3 /g' "${DOCKERFILE}" && \ - sed -i -e 's^/tmp/pip^/tmp/pip3^g' "${DOCKERFILE}" && \ - sed -i -e 's/pip /pip3 /g' "${DOCKERFILE}" && \ - sed -i -e 's/ENV CI_BUILD_PYTHON python/ENV CI_BUILD_PYTHON python3/g' "${DOCKERFILE}" && \ - sed -i -e 's^# RUN ln -s -f /usr/bin/python3 /usr/bin/python#^RUN ln -s -f /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}" - then - echo "Modified Dockerfile further for python version ${TF_DOCKER_BUILD_PYTHON_VERSION} at: ${DOCKERFILE}" + if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then + TF_DOCKER_BUILD_ARGS+=("--build-arg PYTHON=${TF_DOCKER_BUILD_PYTHON_VERSION}") + TF_DOCKER_BUILD_ARGS+=("--build-arg PYTHON3_DEV=python3-dev") + TF_DOCKER_BUILD_ARGS+=("--build-arg WHL_DIR=/tmp/pip3") + TF_DOCKER_BUILD_ARGS+=("--build-arg PIP=pip3") + cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}" else - die "FAILED to modify ${DOCKERFILE} for python3" + if sed -i -e 's/python-dev/python-dev python3-dev/g' "${DOCKERFILE}" && \ + sed -i -e 's/python /python3 /g' "${DOCKERFILE}" && \ + sed -i -e 's^/tmp/pip^/tmp/pip3^g' "${DOCKERFILE}" && \ + sed -i -e 's/pip /pip3 /g' "${DOCKERFILE}" && \ + sed -i -e 's/ENV CI_BUILD_PYTHON python/ENV CI_BUILD_PYTHON python3/g' "${DOCKERFILE}" && \ + sed -i -e 's^# RUN ln -s -f /usr/bin/python3 /usr/bin/python#^RUN ln -s -f /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}" + then + echo "Modified Dockerfile further for python version ${TF_DOCKER_BUILD_PYTHON_VERSION} at: ${DOCKERFILE}" + else + die "FAILED to modify ${DOCKERFILE} for python3" + fi fi fi fi @@ -319,8 +387,11 @@ fi # Intermediate image name with tag IMG="${USER}/tensorflow:${FINAL_TAG}" echo "Building docker image with image name and tag: ${IMG}" +echo "TF_DOCKER_BUILD_ARGS=${TF_DOCKER_BUILD_ARGS[@]}" +CMD="${DOCKER_BINARY} build ${TF_DOCKER_BUILD_ARGS[@]} --no-cache --pull -t ${IMG} -f ${DOCKERFILE} ${TMP_DIR}" +echo "CMD=${CMD}" +${CMD} -"${DOCKER_BINARY}" build --no-cache --pull -t "${IMG}" -f "${DOCKERFILE}" "${TMP_DIR}" if [[ $? == "0" ]]; then echo "${DOCKER_BINARY} build of ${IMG} succeeded" else @@ -340,7 +411,7 @@ fi DOCKER_RUN_LOG="${TMP_DIR}/docker_run.log" echo "" echo "Running docker container from image ${IMG}..." -echo " (Log file is at: ${DOCKER_RUN_LOG}" +echo " Log file is at: ${DOCKER_RUN_LOG}" echo "" if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then @@ -386,7 +457,6 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then # Stop the running docker container sleep 1 "${DOCKER_BINARY}" stop --time=0 ${CONTAINER_ID} - fi diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD index 58b5ef8345c9de83e2d50cd01fe11e11f51fe298..2403e2d966929b86976bf6a31f8144d9b4f58bc6 100644 --- a/tensorflow/tools/docs/BUILD +++ b/tensorflow/tools/docs/BUILD @@ -37,7 +37,11 @@ py_library( srcs = ["parser.py"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = ["@astor_archive//:astor"], + deps = [ + "//tensorflow/python:platform", + "//tensorflow/python:util", + "@astor_archive//:astor", + ], ) py_test( @@ -92,6 +96,7 @@ py_binary( deps = [ ":generate_lib", "//tensorflow:tensorflow_py", + "//tensorflow/python:util", "//tensorflow/python/debug:debug_py", ], ) diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py index 853ec6194f8327f13b3eb6ac7792511c9c4494cd..e7634cd5dcf19d5f21b0bd42b282dfe928659a52 100644 --- a/tensorflow/tools/docs/generate_lib.py +++ b/tensorflow/tools/docs/generate_lib.py @@ -21,6 +21,7 @@ from __future__ import print_function import argparse import fnmatch import os +import shutil import six @@ -81,12 +82,8 @@ def write_docs(output_dir, raise ValueError("'output_dir' must be an absolute path.\n" " output_dir='%s'" % output_dir) - try: - if not os.path.exists(output_dir): - os.makedirs(output_dir) - except OSError as e: - print('Creating output dir "%s" failed: %s' % (output_dir, e)) - raise + if not os.path.exists(output_dir): + os.makedirs(output_dir) # These dictionaries are used for table-of-contents generation below # They will contain, after the for-loop below:: @@ -129,8 +126,6 @@ def write_docs(output_dir, module_children.setdefault(subname, []).append(full_name) break - print('Writing docs for %s (%r).' % (full_name, py_object)) - # Generate docs for `py_object`, resolving references. page_info = parser.docs_for_object(full_name, py_object, parser_config) @@ -151,10 +146,9 @@ def write_docs(output_dir, text = text.encode('utf-8') with open(path, 'wb') as f: f.write(text) - except OSError as e: - print('Cannot write documentation for %s to %s: %s' % (full_name, - directory, e)) - raise + except OSError: + raise OSError( + 'Cannot write documentation for %s to %s' % (full_name, directory)) if yaml_toc: # Generate table of contents @@ -394,16 +388,40 @@ def _build_guide_index(guide_src_dir): class _UpdateTags(py_guide_parser.PyGuideParser): - """Rewrites a Python guide so that each section has an explicit tag.""" + """Rewrites a Python guide so that each section has an explicit id tag. + + "section" here refers to blocks delimited by second level headings. + """ def process_section(self, line_number, section_title, tag): self.replace_line(line_number, '

%s

' % (tag, section_title)) +def update_id_tags_inplace(src_dir): + """Set explicit ids on all second-level headings to ensure back-links work. + + Args: + src_dir: The directory of md-files to convert (inplace). + """ + tag_updater = _UpdateTags() + + for dirpath, _, filenames in os.walk(src_dir): + for base_name in filenames: + if not base_name.endswith('.md'): + continue + full_path = os.path.join(src_dir, dirpath, base_name) + + # Tag updater loads the file, makes the replacements, and returns the + # modified file contents + content = tag_updater.process(full_path) + with open(full_path, 'w') as f: + f.write(content) + + EXCLUDED = set(['__init__.py', 'OWNERS', 'README.txt']) -def _other_docs(src_dir, output_dir, reference_resolver, file_pattern='*.md'): +def replace_refs(src_dir, output_dir, reference_resolver, file_pattern='*.md'): """Fix @{} references in all files under `src_dir` matching `file_pattern`. A matching directory structure, with the modified files is @@ -424,7 +442,6 @@ def _other_docs(src_dir, output_dir, reference_resolver, file_pattern='*.md'): using fnmatch. Non-matching files are copied unchanged. """ # Iterate through all the source files and process them. - tag_updater = _UpdateTags() for dirpath, _, filenames in os.walk(src_dir): # How to get from `dirpath` to api_docs/python/ relative_path_to_root = os.path.relpath( @@ -433,41 +450,32 @@ def _other_docs(src_dir, output_dir, reference_resolver, file_pattern='*.md'): # Make the directory under output_dir. new_dir = os.path.join(output_dir, os.path.relpath(path=dirpath, start=src_dir)) - try: - if not os.path.exists(new_dir): - os.makedirs(new_dir) - except OSError as e: - print('Creating output dir "%s" failed: %s' % (new_dir, e)) - raise + if not os.path.exists(new_dir): + os.makedirs(new_dir) for base_name in filenames: if base_name in EXCLUDED: - print('Skipping excluded file %s...' % base_name) continue full_in_path = os.path.join(dirpath, base_name) + # Set the `current_doc_full_name` so bad files can be reported on errors. reference_resolver.current_doc_full_name = full_in_path suffix = os.path.relpath(path=full_in_path, start=src_dir) full_out_path = os.path.join(output_dir, suffix) + # Copy files that do not match the file_pattern, unmodified. if not fnmatch.fnmatch(base_name, file_pattern): - print('Copying un-matched file %s...' % suffix) - open(full_out_path, 'wb').write(open(full_in_path, 'rb').read()) + shutil.copyfile(full_in_path, full_out_path) continue - if dirpath.endswith('/api_guides/python'): - print('Processing Python guide %s...' % base_name) - content = tag_updater.process(full_in_path) - else: - print('Processing doc %s...' % suffix) - content = open(full_in_path, 'rb').read().decode('utf-8') + + with open(full_in_path, 'rb') as f: + content = f.read().decode('utf-8') content = reference_resolver.replace_references(content, relative_path_to_root) with open(full_out_path, 'wb') as f: f.write(content.encode('utf-8')) - print('Done.') - class DocGenerator(object): """Main entry point for generating docs.""" @@ -554,15 +562,43 @@ class DocGenerator(object): self._do_not_descend_map) def build(self, flags): - """Actually build the docs.""" + """Build all the docs. + + This produces two outputs + + python api docs: + + * generated from modules set with `set_py_modules`. + * written to '{FLAGS.output_dir}/api_docs/python/' + + non-api docs: + + * Everything in '{FLAGS.src_dir}' is copied to '{FLAGS.output_dir}'. + * '@{}' references in '.md' files are replaced with links. + * '.md' files under 'api_guides/python' have explicit ids set for their + second level headings. + + Args: + flags: + * src_dir: Where to fetch the non-api-docs. + * base_dir: Base of the docs directory (Used to build correct + relative links). + * output_dir: Where to write the resulting docs. + + Returns: + The number of errors encountered while processing. + """ + # Extract the python api from the _py_modules doc_index = build_doc_index(flags.src_dir) visitor = self.run_extraction() reference_resolver = self.make_reference_resolver(visitor, doc_index) + # Build the guide_index for the api_docs back links. root_title = getattr(flags, 'root_title', 'TensorFlow') guide_index = _build_guide_index( os.path.join(flags.src_dir, 'api_guides/python')) + # Write the api docs. parser_config = self.make_parser_config(visitor, reference_resolver, guide_index, flags.base_dir) output_dir = os.path.join(flags.output_dir, 'api_docs/python') @@ -573,8 +609,16 @@ class DocGenerator(object): yaml_toc=self.yaml_toc, root_title=root_title, search_hints=getattr(flags, 'search_hints', True)) - _other_docs(flags.src_dir, flags.output_dir, reference_resolver) + # Replace all the @{} references in files under `FLAGS.src_dir` + replace_refs(flags.src_dir, flags.output_dir, reference_resolver, '*.md') + # Fix the tags in the guide dir. + guide_dir = os.path.join(flags.output_dir, 'api_guides/python') + if os.path.exists(guide_dir): + update_id_tags_inplace(guide_dir) + + # Report all errors found by the reference resolver, and return the error + # code. parser_config.reference_resolver.log_errors() return parser_config.reference_resolver.num_errors() diff --git a/tensorflow/tools/docs/generate_lib_test.py b/tensorflow/tools/docs/generate_lib_test.py index ea6d28a02b1f3c07fe8783fd59e345dade1fc804..7a6f9fd9f799db5a14015d77e5297955c76a51cd 100644 --- a/tensorflow/tools/docs/generate_lib_test.py +++ b/tensorflow/tools/docs/generate_lib_test.py @@ -51,7 +51,9 @@ class DummyVisitor(object): class GenerateTest(googletest.TestCase): - def test_write(self): + def get_test_objects(self): + # These are all mutable objects, so rebuild them for each test. + # Don't cache the objects. module = sys.modules[__name__] index = { @@ -98,6 +100,11 @@ class GenerateTest(googletest.TestCase): guide_index={}, base_dir=base_dir) + return reference_resolver, parser_config + + def test_write(self): + _, parser_config = self.get_test_objects() + output_dir = googletest.GetTempDir() generate_lib.write_docs(output_dir, parser_config, yaml_toc=True) @@ -127,6 +134,107 @@ class GenerateTest(googletest.TestCase): os.path.exists( os.path.join(output_dir, 'tf/TestModule/test_function.md'))) + def test_update_id_tags_inplace(self): + test_dir = googletest.GetTempDir() + test_sub_dir = os.path.join(test_dir, 'a/b') + os.makedirs(test_sub_dir) + + test_path1 = os.path.join(test_dir, 'file1.md') + test_path2 = os.path.join(test_sub_dir, 'file2.md') + test_path3 = os.path.join(test_sub_dir, 'file3.notmd') + + with open(test_path1, 'w') as f: + f.write('## abc&123') + + with open(test_path2, 'w') as f: + f.write('# A Level 1 Heading\n') + f.write('## A Level 2 Heading') + + with open(test_path3, 'w') as f: + f.write("## don\'t change this") + + generate_lib.update_id_tags_inplace(test_dir) + + with open(test_path1) as f: + content = f.read() + + self.assertEqual(content, '

abc&123

') + + with open(test_path2) as f: + content = f.read() + + self.assertEqual( + content, '# A Level 1 Heading\n' + '

A Level 2 Heading

') + + with open(test_path3) as f: + content = f.read() + + self.assertEqual(content, "## don\'t change this") + + def test_replace_refes(self): + test_dir = googletest.GetTempDir() + test_in_dir = os.path.join(test_dir, 'in') + test_in_dir_a = os.path.join(test_dir, 'in/a') + test_in_dir_b = os.path.join(test_dir, 'in/b') + os.makedirs(test_in_dir) + os.makedirs(test_in_dir_a) + os.makedirs(test_in_dir_b) + + test_out_dir = os.path.join(test_dir, 'out') + os.makedirs(test_out_dir) + + test_path1 = os.path.join(test_in_dir_a, 'file1.md') + test_path2 = os.path.join(test_in_dir_b, 'file2.md') + test_path3 = os.path.join(test_in_dir_b, 'file3.notmd') + test_path4 = os.path.join(test_in_dir_b, 'OWNERS') + + with open(test_path1, 'w') as f: + f.write('Use `tf.test_function` to test things.') + + with open(test_path2, 'w') as f: + f.write('Use @{tf.TestModule.TestClass.ChildClass} to test things.\n' + "`tf.whatever` doesn't exist") + + with open(test_path3, 'w') as f: + file3_content = ( + 'Not a .md file. Should be copied unchanged:' + '@{tf.TestModule.TestClass.ChildClass}, `tf.test_function`') + f.write(file3_content) + + with open(test_path4, 'w') as f: + f.write('') + + reference_resolver, _ = self.get_test_objects() + generate_lib.replace_refs(test_in_dir, test_out_dir, reference_resolver, + '*.md') + + with open(os.path.join(test_out_dir, 'a/file1.md')) as f: + content = f.read() + self.assertEqual( + content, + 'Use ' + 'tf.test_function to test things.') + + with open(os.path.join(test_out_dir, 'b/file2.md')) as f: + content = f.read() + self.assertEqual( + content, + 'Use ' + '' + 'tf.TestModule.TestClass.ChildClass ' + 'to test things.\n' + '`tf.whatever` doesn\'t exist') + + with open(os.path.join(test_out_dir, 'b/file3.notmd')) as f: + content = f.read() + self.assertEqual(content, file3_content) + + with self.assertRaises(IOError): + # This should fail. The OWNERS file should not be copied + with open(os.path.join(test_out_dir, 'b/OWNERS')) as f: + content = f.read() + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py index 50c90527413d0904c78dab199a68678f6cc91845..ffb93027ed48dd2106c702758917c0846f20cb1c 100644 --- a/tensorflow/tools/docs/parser.py +++ b/tensorflow/tools/docs/parser.py @@ -25,12 +25,12 @@ import itertools import json import os import re -import sys import astor import six from google.protobuf.message import Message as ProtoMessage +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_inspect @@ -53,7 +53,7 @@ class _Errors(object): template = 'ERROR:\n output file name: %s\n %s\n\n' for full_name, message in self._errors: - print(template % (full_name, message), file=sys.stderr) + logging.warn(template, full_name, message) def append(self, full_name, message): """Add an error to the collection. @@ -761,8 +761,9 @@ def _generate_signature(func, reverse_index): lookup_text = public_name + default_text[len(internal_name):] break if default_text is lookup_text: - print('WARNING: Using default arg, failed lookup: %s, repr: %r' % - (default_text, default)) + logging.warn( + 'WARNING: Using default arg, failed lookup: %s, repr: %r', + default_text, default) else: default_text = lookup_text else: @@ -1165,7 +1166,7 @@ class _ClassPageInfo(object): if short_name in [ '__class__', '__base__', '__weakref__', '__doc__', '__module__', '__dict__', '__abstractmethods__', '__slots__', '__getnewargs__', - '__str__', '__repr__', '__hash__' + '__str__', '__repr__', '__hash__', '__reduce__' ]: continue @@ -1213,8 +1214,6 @@ class _ClassPageInfo(object): if not child_doc.brief.strip() and short_name in [ '__del__', '__copy__' ]: - print('Skipping %s, defined in %s, no docstring.' % (child_name, - defining_class)) continue try: @@ -1371,7 +1370,8 @@ class _ModulePageInfo(object): for name in member_names: if name in ['__builtins__', '__doc__', '__file__', - '__name__', '__path__', '__package__']: + '__name__', '__path__', '__package__', + '__cached__', '__loader__', '__spec__']: continue member_full_name = self.full_name + '.' + name if self.full_name else name diff --git a/tensorflow/tools/docs/py_guide_parser.py b/tensorflow/tools/docs/py_guide_parser.py index 328f42d18f1efb0fd82725a4683abad2df0d5a19..b00694dc40322161f180410630bb4dcfd8c2fb18 100644 --- a/tensorflow/tools/docs/py_guide_parser.py +++ b/tensorflow/tools/docs/py_guide_parser.py @@ -44,7 +44,8 @@ class PyGuideParser(object): def process(self, full_path): """Read and process the file at `full_path`.""" - md_string = open(full_path, 'rb').read().decode('utf-8') + with open(full_path, 'rb') as f: + md_string = f.read().decode('utf-8') self._lines = md_string.split('\n') seen = set() diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index 77f83b77a0214110e520c85d15ffa38bce65955f..173f418dc8d998bc51d208a04c8671bacf364cdc 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -115,6 +115,7 @@ genrule( "//third_party/fft2d:LICENSE", "@aws//:LICENSE", "@boringssl//:LICENSE", + "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE", "@com_googlesource_code_re2//:LICENSE", "@cub_archive//:LICENSE.TXT", "@curl//:COPYING", @@ -130,7 +131,7 @@ genrule( "@highwayhash//:LICENSE", "@jemalloc//:COPYING", "@jpeg//:LICENSE.md", - "@libxsmm_archive//:LICENSE", + "@libxsmm_archive//:LICENSE.md", "@llvm//:LICENSE.TXT", "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", @@ -156,6 +157,7 @@ genrule( "//third_party/fft2d:LICENSE", "@aws//:LICENSE", "@boringssl//:LICENSE", + "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE", "@com_googlesource_code_re2//:LICENSE", "@cub_archive//:LICENSE.TXT", "@curl//:COPYING", @@ -168,7 +170,7 @@ genrule( "@highwayhash//:LICENSE", "@jemalloc//:COPYING", "@jpeg//:LICENSE.md", - "@libxsmm_archive//:LICENSE", + "@libxsmm_archive//:LICENSE.md", "@llvm//:LICENSE.TXT", "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index d8356cec4739da61c7c428153642a3eb71d8b935..c9d53f46c3cff9eceb6eb03a872d05e8afd06047 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -57,16 +57,18 @@ COMMON_PIP_DEPS = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/autograph:autograph", "//tensorflow/contrib/autograph/converters:converters", - "//tensorflow/contrib/autograph/converters:test_lib", + "//tensorflow/contrib/autograph/core:core", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/contrib/autograph/impl:impl", - "//tensorflow/contrib/autograph/operators:operators", "//tensorflow/contrib/autograph/lang:lang", + "//tensorflow/contrib/autograph/operators:operators", "//tensorflow/contrib/autograph/pyct:pyct", "//tensorflow/contrib/autograph/pyct/static_analysis:static_analysis", + "//tensorflow/contrib/autograph/pyct/common_transformers:common_transformers", "//tensorflow/contrib/boosted_trees:boosted_trees_pip", "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/constrained_optimization:constrained_optimization_pip", - "//tensorflow/contrib/data/python/kernel_tests:dataset_serialization_test", + "//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base", "//tensorflow/contrib/data/python/ops:contrib_op_loader", "//tensorflow/contrib/eager/python/examples:examples_pip", "//tensorflow/contrib/eager/python:evaluator", @@ -92,6 +94,7 @@ COMMON_PIP_DEPS = [ "//tensorflow/contrib/timeseries:timeseries_pip", "//tensorflow/contrib/tpu", "//tensorflow/examples/tutorials/mnist:package", + "//tensorflow/python:cond_v2", "//tensorflow/python:distributed_framework_test_lib", "//tensorflow/python:meta_graph_testdata", "//tensorflow/python:spectral_ops_test_util", @@ -127,6 +130,8 @@ filegroup( "@astor_archive//:LICENSE", "@aws//:LICENSE", "@boringssl//:LICENSE", + "@com_github_googleapis_googleapis//:LICENSE", + "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE", "@com_google_absl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", "@cub_archive//:LICENSE.TXT", @@ -144,7 +149,7 @@ filegroup( "@jemalloc//:COPYING", "@jpeg//:LICENSE.md", "@kafka//:LICENSE", - "@libxsmm_archive//:LICENSE", + "@libxsmm_archive//:LICENSE.md", "@lmdb//:LICENSE", "@local_config_nccl//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 55cd4f37c682d95461850d312bb48353efd8194f..c630ca04b885d35da6550d4e5f3e6912b5fd7a00 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -53,7 +53,7 @@ REQUIRED_PACKAGES = [ 'gast >= 0.2.0', 'numpy >= 1.13.3', 'six >= 1.10.0', - 'protobuf >= 3.4.0', + 'protobuf >= 3.6.0', 'setuptools <= 39.1.0', 'tensorboard >= 1.8.0, < 1.9.0', 'termcolor >= 1.1.0', @@ -170,8 +170,9 @@ class InstallHeaders(Command): # symlink within the directory hierarchy. # NOTE(keveman): Figure out how to customize bdist_wheel package so # we can do the symlink. - if 'external/eigen_archive/' in install_dir: - extra_dir = install_dir.replace('external/eigen_archive', '') + if 'tensorflow/include/external/eigen_archive/' in install_dir: + extra_dir = install_dir.replace( + 'tensorflow/include/external/eigen_archive', '') if not os.path.exists(extra_dir): self.mkpath(extra_dir) self.copy_file(header, extra_dir) @@ -204,13 +205,12 @@ def find_files(pattern, root): yield os.path.join(dirpath, filename) -matches = ['../' + x for x in find_files('*', 'external') if '.py' not in x] - so_lib_paths = [ i for i in os.listdir('.') if os.path.isdir(i) and fnmatch.fnmatch(i, '_solib_*') ] +matches = [] for path in so_lib_paths: matches.extend( ['../' + x for x in find_files('*', path) if '.py' not in x] @@ -225,7 +225,7 @@ headers = (list(find_files('*.h', 'tensorflow/core')) + list(find_files('*.h', 'tensorflow/stream_executor')) + list(find_files('*.h', 'google/protobuf_archive/src')) + list(find_files('*', 'third_party/eigen3')) + - list(find_files('*', 'external/eigen_archive'))) + list(find_files('*', 'tensorflow/include/external/eigen_archive'))) setup( name=project_name, diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 09f7a9b7dd6342f4a05a6f3e68e91de909f716a2..2fe0b6f0723af8e629ecedb4280e67de2e5b0e2f 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -107,11 +107,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "eigen_archive", urls = [ - "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/267806ed9b4f.tar.gz", - "https://bitbucket.org/eigen/eigen/get/267806ed9b4f.tar.gz", + "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/e5e305a158a0.tar.gz", + "https://bitbucket.org/eigen/eigen/get/e5e305a158a0.tar.gz", ], - sha256 = "ade57357093463cab9e4e51cd5749c81483a75451b1471a3ebc73f9c1d14043b", - strip_prefix = "eigen-eigen-267806ed9b4f", + sha256 = "8bbe676d69e7f59070c83a949454b8b6344034e0ebbf686b337528e5dc04c7de", + strip_prefix = "eigen-eigen-e5e305a158a0", build_file = clean_dep("//third_party:eigen.BUILD"), ) @@ -131,11 +131,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "libxsmm_archive", urls = [ - "https://mirror.bazel.build/github.com/hfp/libxsmm/archive/1.8.1.tar.gz", - "https://github.com/hfp/libxsmm/archive/1.8.1.tar.gz", + "https://mirror.bazel.build/github.com/hfp/libxsmm/archive/1.9.tar.gz", + "https://github.com/hfp/libxsmm/archive/1.9.tar.gz", ], - sha256 = "2ade869c3f42f23b5263c7d594aa3c7e5e61ac6a3afcaf5d6e42899d2a7986ce", - strip_prefix = "libxsmm-1.8.1", + sha256 = "cd8532021352b4a0290d209f7f9bfd7c2411e08286a893af3577a43457287bfa", + strip_prefix = "libxsmm-1.9", build_file = clean_dep("//third_party:libxsmm.BUILD"), ) @@ -155,12 +155,33 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "com_googlesource_code_re2", urls = [ - "https://mirror.bazel.build/github.com/google/re2/archive/26cd968b735e227361c9703683266f01e5df7857.tar.gz", - "https://github.com/google/re2/archive/26cd968b735e227361c9703683266f01e5df7857.tar.gz", + "https://mirror.bazel.build/github.com/google/re2/archive/2018-04-01.tar.gz", + "https://github.com/google/re2/archive/2018-04-01.tar.gz", ], - sha256 = "e57eeb837ac40b5be37b2c6197438766e73343ffb32368efea793dfd8b28653b", - strip_prefix = "re2-26cd968b735e227361c9703683266f01e5df7857", + sha256 = "2f945446b71336e7f5a2bcace1abcf0b23fbba368266c6a1be33de3de3b3c912", + strip_prefix = "re2-2018-04-01", + ) + + tf_http_archive( + name = "com_github_googlecloudplatform_google_cloud_cpp", + urls = [ + "https://mirror.bazel.build/github.com/GoogleCloudPlatform/google-cloud-cpp/archive/f9ff105957965bcf87f7cb9a93e951c3d08d1734.tar.gz", + "https://github.com/GoogleCloudPlatform/google-cloud-cpp/archive/f9ff105957965bcf87f7cb9a93e951c3d08d1734.tar.gz", + ], + sha256 = "edb347aae9869ffdcf8df6288335bcc535fec46da946b385c16968e96a74b208", + strip_prefix = "google-cloud-cpp-f9ff105957965bcf87f7cb9a93e951c3d08d1734", + ) + + tf_http_archive( + name = "com_github_googleapis_googleapis", + urls = [ + "https://mirror.bazel.build/github.com/googleapis/googleapis/archive/f81082ea1e2f85c43649bee26e0d9871d4b41cdb.zip", + "https://github.com/googleapis/googleapis/archive/f81082ea1e2f85c43649bee26e0d9871d4b41cdb.zip", + ], + sha256 = "824870d87a176f26bcef663e92051f532fac756d1a06b404055dc078425f4378", + strip_prefix="googleapis-f81082ea1e2f85c43649bee26e0d9871d4b41cdb", + build_file = clean_dep("//third_party:googleapis.BUILD"), ) tf_http_archive( @@ -299,11 +320,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "absl_py", urls = [ - "https://mirror.bazel.build/github.com/abseil/abseil-py/archive/ea8c4d2ddbf3fba610c4d613260561699b776db8.tar.gz", - "https://github.com/abseil/abseil-py/archive/ea8c4d2ddbf3fba610c4d613260561699b776db8.tar.gz", + "https://mirror.bazel.build/github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz", + "https://github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz", ], - sha256 = "c30b48e0d2580ef1412e55c5c0e1dab8db2ee4ab56e2075eccff29c90c7c7059", - strip_prefix = "abseil-py-ea8c4d2ddbf3fba610c4d613260561699b776db8", + sha256 = "95160f778a62c7a60ddeadc7bf2d83f85a23a27359814aca12cf949e896fa82c", + strip_prefix = "abseil-py-pypi-v0.2.2", ) tf_http_archive( @@ -331,11 +352,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "protobuf_archive", urls = [ - "https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", - "https://github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", + "https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz", + "https://github.com/google/protobuf/archive/v3.6.0.tar.gz", ], - sha256 = "846d907acf472ae233ec0882ef3a2d24edbbe834b80c305e867ac65a1f2c59e3", - strip_prefix = "protobuf-396336eb961b75f03b25824fe86cf6490fb75e3a", + sha256 = "50a5753995b3142627ac55cfd496cebc418a2e575ca0236e29033c67bd5665f4", + strip_prefix = "protobuf-3.6.0", ) # We need to import the protobuf library under the names com_google_protobuf @@ -344,31 +365,31 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "com_google_protobuf", urls = [ - "https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", - "https://github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", + "https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz", + "https://github.com/google/protobuf/archive/v3.6.0.tar.gz", ], - sha256 = "846d907acf472ae233ec0882ef3a2d24edbbe834b80c305e867ac65a1f2c59e3", - strip_prefix = "protobuf-396336eb961b75f03b25824fe86cf6490fb75e3a", + sha256 = "50a5753995b3142627ac55cfd496cebc418a2e575ca0236e29033c67bd5665f4", + strip_prefix = "protobuf-3.6.0", ) tf_http_archive( name = "com_google_protobuf_cc", urls = [ - "https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", - "https://github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", + "https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz", + "https://github.com/google/protobuf/archive/v3.6.0.tar.gz", ], - sha256 = "846d907acf472ae233ec0882ef3a2d24edbbe834b80c305e867ac65a1f2c59e3", - strip_prefix = "protobuf-396336eb961b75f03b25824fe86cf6490fb75e3a", + sha256 = "50a5753995b3142627ac55cfd496cebc418a2e575ca0236e29033c67bd5665f4", + strip_prefix = "protobuf-3.6.0", ) tf_http_archive( name = "nsync", urls = [ - "https://mirror.bazel.build/github.com/google/nsync/archive/0559ce013feac8db639ee1bf776aca0325d28777.tar.gz", - "https://github.com/google/nsync/archive/0559ce013feac8db639ee1bf776aca0325d28777.tar.gz", + "https://mirror.bazel.build/github.com/google/nsync/archive/1.20.0.tar.gz", + "https://github.com/google/nsync/archive/1.20.0.tar.gz", ], - sha256 = "6284454c5cd8b1dae2eeb8cf5eb63004de930b5427ed5f6b1aa793513df6b361", - strip_prefix = "nsync-0559ce013feac8db639ee1bf776aca0325d28777", + sha256 = "0c1b03962b2f8450f21e74a5a46116bf2d6009a807c57eb4207e974a8c4bb7dd", + strip_prefix = "nsync-1.20.0", ) tf_http_archive( @@ -393,12 +414,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "pcre", - sha256 = "ccdf7e788769838f8285b3ee672ed573358202305ee361cfec7a4a4fb005bbc7", + sha256 = "69acbc2fbdefb955d42a4c606dfde800c2885711d2979e356c0636efde9ec3b5", urls = [ - "https://mirror.bazel.build/ftp.exim.org/pub/pcre/pcre-8.39.tar.gz", - "http://ftp.exim.org/pub/pcre/pcre-8.39.tar.gz", + "https://mirror.bazel.build/ftp.exim.org/pub/pcre/pcre-8.42.tar.gz", + "http://ftp.exim.org/pub/pcre/pcre-8.42.tar.gz", ], - strip_prefix = "pcre-8.39", + strip_prefix = "pcre-8.42", build_file = clean_dep("//third_party:pcre.BUILD"), ) @@ -416,12 +437,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "curl", - sha256 = "ff3e80c1ca6a068428726cd7dd19037a47cc538ce58ef61c59587191039b2ca6", + sha256 = "e9c37986337743f37fd14fe8737f246e97aec94b39d1b71e8a5973f72a9fc4f5", urls = [ - "https://mirror.bazel.build/curl.haxx.se/download/curl-7.49.1.tar.gz", - "https://curl.haxx.se/download/curl-7.49.1.tar.gz", + "https://mirror.bazel.build/curl.haxx.se/download/curl-7.60.0.tar.gz", + "https://curl.haxx.se/download/curl-7.60.0.tar.gz", ], - strip_prefix = "curl-7.49.1", + strip_prefix = "curl-7.60.0", build_file = clean_dep("//third_party:curl.BUILD"), ) @@ -452,33 +473,33 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/45a02a4f8474b4b8c5cc106b5cecb06cf6e1b3c6.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/45a02a4f8474b4b8c5cc106b5cecb06cf6e1b3c6.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/fe1e7736763a8577ac081eca525e05d3b52de414.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/fe1e7736763a8577ac081eca525e05d3b52de414.tar.gz", ], - sha256 = "056f7316a354d1f95e013176bd9b8be74e8f4d47fb0d908e0e742613187dbd59", - strip_prefix = "llvm-45a02a4f8474b4b8c5cc106b5cecb06cf6e1b3c6", - build_file = clean_dep("//third_party/llvm:llvm.BUILD"), + sha256 = "77b9a98d3c0be94561fed32f44a7a8c78421e01a74bad009964d8bbaf066ed6c", + strip_prefix = "llvm-fe1e7736763a8577ac081eca525e05d3b52de414", + build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"), ) tf_http_archive( name = "lmdb", urls = [ - "https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz", - "https://github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz", + "https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.22.tar.gz", + "https://github.com/LMDB/lmdb/archive/LMDB_0.9.22.tar.gz", ], - sha256 = "108532fb94c6f227558d45be3f3347b52539f0f58290a7bb31ec06c462d05326", - strip_prefix = "lmdb-LMDB_0.9.19/libraries/liblmdb", + sha256 = "f3927859882eb608868c8c31586bb7eb84562a40a6bf5cc3e13b6b564641ea28", + strip_prefix = "lmdb-LMDB_0.9.22/libraries/liblmdb", build_file = clean_dep("//third_party:lmdb.BUILD"), ) tf_http_archive( name = "jsoncpp_git", urls = [ - "https://mirror.bazel.build/github.com/open-source-parsers/jsoncpp/archive/11086dd6a7eba04289944367ca82cea71299ed70.tar.gz", - "https://github.com/open-source-parsers/jsoncpp/archive/11086dd6a7eba04289944367ca82cea71299ed70.tar.gz", + "https://mirror.bazel.build/github.com/open-source-parsers/jsoncpp/archive/1.8.4.tar.gz", + "https://github.com/open-source-parsers/jsoncpp/archive/1.8.4.tar.gz", ], - sha256 = "07d34db40593d257324ec5fb9debc4dc33f29f8fb44e33a2eeb35503e61d0fe2", - strip_prefix = "jsoncpp-11086dd6a7eba04289944367ca82cea71299ed70", + sha256 = "c49deac9e0933bcb7044f08516861a2d560988540b23de2ac1ad443b219afdb6", + strip_prefix = "jsoncpp-1.8.4", build_file = clean_dep("//third_party:jsoncpp.BUILD"), ) @@ -538,11 +559,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "kafka", urls = [ - "https://mirror.bazel.build/github.com/edenhill/librdkafka/archive/v0.11.1.tar.gz", - "https://github.com/edenhill/librdkafka/archive/v0.11.1.tar.gz", + "https://mirror.bazel.build/github.com/edenhill/librdkafka/archive/v0.11.4.tar.gz", + "https://github.com/edenhill/librdkafka/archive/v0.11.4.tar.gz", ], - sha256 = "dd035d57c8f19b0b612dd6eefe6e5eebad76f506e302cccb7c2066f25a83585e", - strip_prefix = "librdkafka-0.11.1", + sha256 = "9d8f1eb7b0e29e9ab1168347c939cb7ae5dff00a39cef99e7ef033fd8f92737c", + strip_prefix = "librdkafka-0.11.4", build_file = clean_dep("//third_party:kafka/BUILD"), patch_file = clean_dep("//third_party/kafka:config.patch"), ) @@ -695,11 +716,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "flatbuffers", - strip_prefix = "flatbuffers-971a68110e4fc1bace10fcb6deeb189e7e1a34ce", - sha256 = "874088d2ee0d9f8524191f77209556415f03dd44e156276edf19e5b90ceb5f55", + strip_prefix = "flatbuffers-1.9.0", + sha256 = "5ca5491e4260cacae30f1a5786d109230db3f3a6e5a0eb45d0d0608293d247e3", urls = [ - "https://mirror.bazel.build/github.com/google/flatbuffers/archive/971a68110e4fc1bace10fcb6deeb189e7e1a34ce.tar.gz", - "https://github.com/google/flatbuffers/archive/971a68110e4fc1bace10fcb6deeb189e7e1a34ce.tar.gz", + "https://mirror.bazel.build/github.com/google/flatbuffers/archive/v1.9.0.tar.gz", + "https://github.com/google/flatbuffers/archive/v1.9.0.tar.gz", ], build_file = clean_dep("//third_party/flatbuffers:flatbuffers.BUILD"), ) @@ -765,6 +786,16 @@ def tf_workspace(path_prefix="", tf_repo_name=""): strip_prefix = "ovic", ) + tf_http_archive( + name = "build_bazel_rules_android", + sha256 = "cd06d15dd8bb59926e4d65f9003bfc20f9da4b2519985c27e190cddc8b7a7806", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/rules_android/archive/v0.1.1.zip", + "https://github.com/bazelbuild/rules_android/archive/v0.1.1.zip", + ], + strip_prefix = "rules_android-0.1.1", + ) + ############################################################################## # BIND DEFINITIONS # diff --git a/third_party/clang_toolchain/download_clang.bzl b/third_party/clang_toolchain/download_clang.bzl index a203245005cc215250380239e5ac4d1dbc209d97..a014a806a69ecf9d7e43c51daf3672fc5750e706 100644 --- a/third_party/clang_toolchain/download_clang.bzl +++ b/third_party/clang_toolchain/download_clang.bzl @@ -35,18 +35,18 @@ def download_clang(repo_ctx, out_folder): # Latest CLANG_REVISION and CLANG_SUB_REVISION of the Chromiums's release # can be found in https://chromium.googlesource.com/chromium/src/tools/clang/+/master/scripts/update.py - CLANG_REVISION = '332838' + CLANG_REVISION = '335091' CLANG_SUB_REVISION = 1 package_version = '%s-%s' % (CLANG_REVISION, CLANG_SUB_REVISION) checksums = { 'Linux_x64': - 'b9ef55de7500778f366039dbe62d1632074a3ef3673022eabf4e59d405730968', + '17002b75293fccfdd175eacdc9ee47d97b58d7e98fef343384fbbef1b68ce99f', 'Mac': - '30d808512763c98cecf15f7bb654d845de3e8d065a95f5c5b6b3459254cc98d6', + '9351e46d28315daaa06a1eb55bd0370ed4aaeb693a2a3e82e48d2737d7723468', 'Win': - '277e799a190b22727c26b09986c0cedbd667a189f425318f421addf6a21ca4bd', + 'e78a1e469224d6f6751b4df4374bf58893ac03900ec924e4c8264888ba4aeb1e', } platform_folder = _get_platform_folder(repo_ctx.os.name) diff --git a/third_party/curl.BUILD b/third_party/curl.BUILD index 4def6f94892329e0d8b594b824babd60ea259351..1638b7216162abca208267ff804c6d92231081f6 100644 --- a/third_party/curl.BUILD +++ b/third_party/curl.BUILD @@ -7,6 +7,7 @@ exports_files(["COPYING"]) CURL_WIN_COPTS = [ "/Iexternal/curl/lib", + "/DBUILDING_LIBCURL", "/DHAVE_CONFIG_H", "/DCURL_DISABLE_FTP", "/DCURL_DISABLE_NTLM", @@ -49,6 +50,8 @@ cc_library( "lib/curl_addrinfo.c", "lib/curl_addrinfo.h", "lib/curl_base64.h", + "lib/curl_ctype.c", + "lib/curl_ctype.h", "lib/curl_des.h", "lib/curl_endian.h", "lib/curl_fnmatch.c", @@ -75,6 +78,7 @@ cc_library( "lib/curl_sec.h", "lib/curl_setup.h", "lib/curl_setup_once.h", + "lib/curl_sha256.h", "lib/curl_sspi.c", "lib/curl_sspi.h", "lib/curl_threads.c", @@ -134,6 +138,8 @@ cc_library( "lib/md5.c", "lib/memdebug.c", "lib/memdebug.h", + "lib/mime.c", + "lib/mime.h", "lib/mprintf.c", "lib/multi.c", "lib/multihandle.h", @@ -153,8 +159,8 @@ cc_library( "lib/pop3.h", "lib/progress.c", "lib/progress.h", - "lib/rawstr.c", - "lib/rawstr.h", + "lib/rand.c", + "lib/rand.h", "lib/rtsp.c", "lib/rtsp.h", "lib/security.c", @@ -162,8 +168,11 @@ cc_library( "lib/select.h", "lib/sendf.c", "lib/sendf.h", + "lib/setopt.c", + "lib/setopt.h", "lib/setup-os400.h", "lib/setup-vms.h", + "lib/sha256.c", "lib/share.c", "lib/share.h", "lib/sigpipe.h", @@ -179,10 +188,10 @@ cc_library( "lib/splay.c", "lib/splay.h", "lib/ssh.h", + "lib/strcase.c", + "lib/strcase.h", "lib/strdup.c", "lib/strdup.h", - "lib/strequal.c", - "lib/strequal.h", "lib/strerror.c", "lib/strerror.h", "lib/strtok.c", @@ -241,13 +250,12 @@ cc_library( }), hdrs = [ "include/curl/curl.h", - "include/curl/curlbuild.h", - "include/curl/curlrules.h", "include/curl/curlver.h", "include/curl/easy.h", "include/curl/mprintf.h", "include/curl/multi.h", "include/curl/stdcheaders.h", + "include/curl/system.h", "include/curl/typecheck-gcc.h", ], copts = select({ @@ -256,6 +264,7 @@ cc_library( "//conditions:default": [ "-Iexternal/curl/lib", "-D_GNU_SOURCE", + "-DBUILDING_LIBCURL", "-DHAVE_CONFIG_H", "-DCURL_DISABLE_FTP", "-DCURL_DISABLE_NTLM", # turning it off in configure is not enough @@ -676,6 +685,7 @@ genrule( "# define SIZEOF_INT 4", "# define SIZEOF_LONG 8", "# define SIZEOF_OFF_T 8", + "# define SIZEOF_CURL_OFF_T 8", "# define SIZEOF_SHORT 2", "# define SIZEOF_SIZE_T 8", "# define SIZEOF_TIME_T 8", diff --git a/third_party/eigen.BUILD b/third_party/eigen.BUILD index e54c1a4501d46b6b68a9b8fcc9ce0b1af0535ef4..759f8a9be92e14537d334c3ec37f036d369d8796 100644 --- a/third_party/eigen.BUILD +++ b/third_party/eigen.BUILD @@ -69,3 +69,9 @@ cc_library( includes = ["."], visibility = ["//visibility:public"], ) + +filegroup( + name = "eigen_header_files", + srcs = EIGEN_MPL2_HEADER_FILES, + visibility = ["//visibility:public"], +) diff --git a/third_party/eigen3/BUILD b/third_party/eigen3/BUILD index f661093bc9f68b845f3000b0a931c66773fb3339..9d9c27b180fb670cccb27dc7d6b8445927bfabce 100644 --- a/third_party/eigen3/BUILD +++ b/third_party/eigen3/BUILD @@ -17,21 +17,23 @@ load("//tensorflow:tensorflow.bzl", "if_mkl") # INTEL_MKL end load("//tensorflow:tensorflow.bzl", "if_mkl") +EIGEN3_THIRD_PARTY_HEADERS = [ + "Eigen/Core", + "Eigen/LU", + "Eigen/Cholesky", + "Eigen/Eigenvalues", + "Eigen/QR", + "Eigen/SVD", + "unsupported/Eigen/MatrixFunctions", + "unsupported/Eigen/SpecialFunctions", + "unsupported/Eigen/CXX11/ThreadPool", + "unsupported/Eigen/CXX11/Tensor", + "unsupported/Eigen/CXX11/FixedPoint", +] + glob(["unsupported/Eigen/CXX11/src/FixedPoint/*.h"]) + cc_library( name = "eigen3", - hdrs = glob(["unsupported/Eigen/CXX11/src/FixedPoint/*.h"]) + [ - "Eigen/Core", - "Eigen/LU", - "Eigen/Cholesky", - "Eigen/Eigenvalues", - "Eigen/QR", - "Eigen/SVD", - "unsupported/Eigen/MatrixFunctions", - "unsupported/Eigen/SpecialFunctions", - "unsupported/Eigen/CXX11/ThreadPool", - "unsupported/Eigen/CXX11/Tensor", - "unsupported/Eigen/CXX11/FixedPoint", - ], + hdrs = EIGEN3_THIRD_PARTY_HEADERS, includes = if_mkl(["./mkl_include"]), visibility = ["//visibility:public"], deps = [ @@ -48,3 +50,35 @@ filegroup( ), visibility = ["//tensorflow:__subpackages__"], ) + +filegroup( + name = "eigen_third_party_header_files", + srcs = EIGEN3_THIRD_PARTY_HEADERS, + visibility = ["//visibility:public"], +) + +genrule( + name = "install_eigen_headers", + srcs = [ + "@eigen_archive//:eigen_header_files", + ":eigen_third_party_header_files", + ], + outs = ["include"], + cmd = """ + mkdir $@ + for f in $(locations @eigen_archive//:eigen_header_files) ; do + d="$${f%/*}" + d="$${d#*external/eigen_archive/}" + + mkdir -p "$@/$${d}" + cp "$${f}" "$@/$${d}/" + done + + for f in $(locations :eigen_third_party_header_files) ; do + d="$${f%/*}" + + mkdir -p "$@/$${d}" + cp "$${f}" "$@/$${d}/" + done + """ +) diff --git a/third_party/examples/eager/spinn/README.md b/third_party/examples/eager/spinn/README.md index fbb1fde837b92bc521698d0a517a946da0438dbc..e2fd8009a052d7cbfd01b48af7da6b891ad08c74 100644 --- a/third_party/examples/eager/spinn/README.md +++ b/third_party/examples/eager/spinn/README.md @@ -22,7 +22,7 @@ Other eager execution examples can be found under [tensorflow/contrib/eager/pyth - [`data.py`](../../../../tensorflow/contrib/eager/python/examples/spinn/data.py): Pipeline for loading and preprocessing the [SNLI](https://nlp.stanford.edu/projects/snli/) data and [GloVe](https://nlp.stanford.edu/projects/glove/) word embedding, written - using the [`tf.data`](https://www.tensorflow.org/programmers_guide/datasets) + using the [`tf.data`](https://www.tensorflow.org/guide/datasets) API. - [`spinn.py`](./spinn.py): Model definition and training routines. This example illustrates how one might perform the following actions with diff --git a/third_party/flatbuffers/flatbuffers.BUILD b/third_party/flatbuffers/flatbuffers.BUILD index 824c97be60e7ef148a363b964ed330ba3c5fcb0c..639dff2cd01056cf70e727b39c0a0c537c763c9e 100644 --- a/third_party/flatbuffers/flatbuffers.BUILD +++ b/third_party/flatbuffers/flatbuffers.BUILD @@ -98,6 +98,8 @@ cc_binary( "grpc/src/compiler/cpp_generator.h", "grpc/src/compiler/go_generator.cc", "grpc/src/compiler/go_generator.h", + "grpc/src/compiler/java_generator.cc", + "grpc/src/compiler/java_generator.h", "grpc/src/compiler/schema_interface.h", "src/flatc_main.cpp", "src/idl_gen_cpp.cpp", diff --git a/third_party/googleapis.BUILD b/third_party/googleapis.BUILD new file mode 100644 index 0000000000000000000000000000000000000000..95e999af1886576317aa59d133e8d5c88ba368d3 --- /dev/null +++ b/third_party/googleapis.BUILD @@ -0,0 +1,45 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//visibility:public"]) +licenses(["notice"]) # Apache 2.0 +exports_files(["LICENSE"]) + +load("@protobuf_archive//:protobuf.bzl", "cc_proto_library") + +cc_proto_library( + name = "bigtable_protos", + srcs = [ + "google/bigtable/admin/v2/bigtable_instance_admin.proto", + "google/bigtable/admin/v2/bigtable_table_admin.proto", + "google/bigtable/admin/v2/common.proto", + "google/bigtable/admin/v2/instance.proto", + "google/bigtable/admin/v2/table.proto", + "google/bigtable/v2/bigtable.proto", + "google/bigtable/v2/data.proto", + "google/iam/v1/iam_policy.proto", + "google/iam/v1/policy.proto", + "google/longrunning/operations.proto", + "google/rpc/status.proto", + "google/rpc/error_details.proto", + "google/api/annotations.proto", + "google/api/auth.proto", + "google/api/http.proto", + ], + include = ".", + protoc = "@protobuf_archive//:protoc", + default_runtime = "@protobuf_archive//:protobuf", + deps = ["@protobuf_archive//:cc_wkt_protos"], + use_grpc_plugin = True, +) diff --git a/third_party/jsoncpp.BUILD b/third_party/jsoncpp.BUILD index 65f98410b289a7e324c9ed89e33de1c6010fa21a..cf3cba05556a0bb22a632475c6ab810b8230f355 100644 --- a/third_party/jsoncpp.BUILD +++ b/third_party/jsoncpp.BUILD @@ -6,7 +6,6 @@ cc_library( name = "jsoncpp", srcs = [ "include/json/assertions.h", - "src/lib_json/json_batchallocator.h", "src/lib_json/json_reader.cpp", "src/lib_json/json_tool.h", "src/lib_json/json_value.cpp", @@ -20,9 +19,13 @@ cc_library( "include/json/json.h", "include/json/reader.h", "include/json/value.h", + "include/json/version.h", "include/json/writer.h", ], - copts = ["-DJSON_USE_EXCEPTION=0"], + copts = [ + "-DJSON_USE_EXCEPTION=0", + "-DJSON_HAS_INT64", + ], includes = ["include"], visibility = ["//visibility:public"], deps = [":private"], diff --git a/third_party/kafka/BUILD b/third_party/kafka/BUILD index a839ca717e695f35fac684b510f0a022010e0710..75792b0d87366c304ca29f95f943114ee482dfcd 100644 --- a/third_party/kafka/BUILD +++ b/third_party/kafka/BUILD @@ -60,6 +60,8 @@ cc_library( "src/rdkafka_event.h", "src/rdkafka_feature.c", "src/rdkafka_feature.h", + "src/rdkafka_header.c", + "src/rdkafka_header.h", "src/rdkafka_int.h", "src/rdkafka_interceptor.c", "src/rdkafka_interceptor.h", @@ -93,7 +95,6 @@ cc_library( "src/rdkafka_sasl_int.h", "src/rdkafka_sasl_plain.c", "src/rdkafka_subscription.c", - "src/rdkafka_subscription.h", "src/rdkafka_timer.c", "src/rdkafka_timer.h", "src/rdkafka_topic.c", @@ -105,6 +106,8 @@ cc_library( "src/rdlist.h", "src/rdlog.c", "src/rdlog.h", + "src/rdmurmur2.c", + "src/rdmurmur2.h", "src/rdports.c", "src/rdports.h", "src/rdposix.h", diff --git a/third_party/libxsmm.BUILD b/third_party/libxsmm.BUILD index 78ed1f4e168891367ddc2249da726a6ef16dd5d5..ee49d281abcd54b566edde119f4a5b3e6b07d2a3 100644 --- a/third_party/libxsmm.BUILD +++ b/third_party/libxsmm.BUILD @@ -3,7 +3,7 @@ licenses(["notice"]) # BSD 3-clause -exports_files(["LICENSE"]) +exports_files(["LICENSE.md"]) # Arguments to ./scripts/libxsmm_interface.py, see that file for detailed description. # precision: SP & DP diff --git a/third_party/llvm/llvm.BUILD b/third_party/llvm/llvm.autogenerated.BUILD similarity index 89% rename from third_party/llvm/llvm.BUILD rename to third_party/llvm/llvm.autogenerated.BUILD index e1c22c815196cc9be0af763ae6400ecb40555e4e..d931932d9d517cb5f0638a87569b697e35e158f6 100644 --- a/third_party/llvm/llvm.BUILD +++ b/third_party/llvm/llvm.autogenerated.BUILD @@ -8,10 +8,13 @@ exports_files(["LICENSE.TXT"]) load( "@org_tensorflow//third_party/llvm:llvm.bzl", + "LLVM_COPTS", + "LLVM_DEFINES", + "LLVM_LINKOPTS", "cmake_var_string", "expand_cmake_vars", "gentbl", - "llvm_target_cmake_vars", + "llvm_all_cmake_vars", ) load( "@org_tensorflow//third_party:common.bzl", @@ -39,147 +42,25 @@ llvm_target_asm_printers = llvm_targets llvm_target_disassemblers = llvm_targets -# TODO(phawkins): the set of CMake variables was hardcoded for expediency. -# However, we should really detect many of these via configure-time tests. - -# The set of CMake variables common to all targets. -cmake_vars = { - # Headers - "HAVE_DIRENT_H": 1, - "HAVE_DLFCN_H": 1, - "HAVE_ERRNO_H": 1, - "HAVE_EXECINFO_H": 1, - "HAVE_FCNTL_H": 1, - "HAVE_INTTYPES_H": 1, - "HAVE_PTHREAD_H": 1, - "HAVE_SIGNAL_H": 1, - "HAVE_STDINT_H": 1, - "HAVE_SYS_IOCTL_H": 1, - "HAVE_SYS_MMAN_H": 1, - "HAVE_SYS_PARAM_H": 1, - "HAVE_SYS_RESOURCE_H": 1, - "HAVE_SYS_STAT_H": 1, - "HAVE_SYS_TIME_H": 1, - "HAVE_SYS_TYPES_H": 1, - "HAVE_TERMIOS_H": 1, - "HAVE_UNISTD_H": 1, - "HAVE_ZLIB_H": 1, - - # Features - "HAVE_BACKTRACE": 1, - "BACKTRACE_HEADER": "execinfo.h", - "HAVE_DLOPEN": 1, - "HAVE_FUTIMES": 1, - "HAVE_GETCWD": 1, - "HAVE_GETPAGESIZE": 1, - "HAVE_GETRLIMIT": 1, - "HAVE_GETRUSAGE": 1, - "HAVE_GETTIMEOFDAY": 1, - "HAVE_INT64_T": 1, - "HAVE_ISATTY": 1, - "HAVE_LIBEDIT": 1, - "HAVE_LIBPTHREAD": 1, - "HAVE_LIBZ": 1, - "HAVE_MKDTEMP": 1, - "HAVE_MKSTEMP": 1, - "HAVE_MKTEMP": 1, - "HAVE_PREAD": 1, - "HAVE_PTHREAD_GETSPECIFIC": 1, - "HAVE_PTHREAD_MUTEX_LOCK": 1, - "HAVE_PTHREAD_RWLOCK_INIT": 1, - "HAVE_REALPATH": 1, - "HAVE_SBRK": 1, - "HAVE_SETENV": 1, - "HAVE_SETRLIMIT": 1, - "HAVE_SIGALTSTACK": 1, - "HAVE_STRERROR": 1, - "HAVE_STRERROR_R": 1, - "HAVE_STRTOLL": 1, - "HAVE_SYSCONF": 1, - "HAVE_UINT64_T": 1, - "HAVE__UNWIND_BACKTRACE": 1, - - # LLVM features - "ENABLE_BACKTRACES": 1, - "LLVM_BINDIR": "/dev/null", - "LLVM_DISABLE_ABI_BREAKING_CHECKS_ENFORCING": 0, - "LLVM_ENABLE_ABI_BREAKING_CHECKS": 0, - "LLVM_ENABLE_THREADS": 1, - "LLVM_ENABLE_ZLIB": 1, - "LLVM_HAS_ATOMICS": 1, - "LLVM_INCLUDEDIR": "/dev/null", - "LLVM_INFODIR": "/dev/null", - "LLVM_MANDIR": "/dev/null", - "LLVM_NATIVE_TARGET": 1, - "LLVM_NATIVE_TARGETINFO": 1, - "LLVM_NATIVE_TARGETMC": 1, - "LLVM_NATIVE_ASMPRINTER": 1, - "LLVM_NATIVE_ASMPARSER": 1, - "LLVM_NATIVE_DISASSEMBLER": 1, - "LLVM_ON_UNIX": 1, - "LLVM_PREFIX": "/dev/null", - "LLVM_VERSION_MAJOR": 0, - "LLVM_VERSION_MINOR": 0, - "LLVM_VERSION_PATCH": 0, - "LTDL_SHLIB_EXT": ".so", - "PACKAGE_NAME": "llvm", - "PACKAGE_STRING": "llvm tensorflow-trunk", - "PACKAGE_VERSION": "tensorflow-trunk", - "RETSIGTYPE": "void", -} - -# CMake variables specific to the Linux platform -linux_cmake_vars = { - "HAVE_MALLOC_H": 1, - "HAVE_LINK_H": 1, - "HAVE_MALLINFO": 1, - "HAVE_FUTIMENS": 1, -} - -# CMake variables specific to the Darwin (Mac OS X) platform. -darwin_cmake_vars = { - "HAVE_MALLOC_MALLOC_H": 1, -} - -# Select a set of CMake variables based on the platform. -# TODO(phawkins): use a better method to select the right host triple, rather -# than hardcoding x86_64. -all_cmake_vars = select({ - "@org_tensorflow//tensorflow:darwin": cmake_var_string( - cmake_vars + llvm_target_cmake_vars("X86", "x86_64-apple-darwin") + - darwin_cmake_vars, - ), - "@org_tensorflow//tensorflow:linux_ppc64le": cmake_var_string( - cmake_vars + - llvm_target_cmake_vars("PowerPC", "powerpc64le-unknown-linux_gnu") + - linux_cmake_vars, - ), - "//conditions:default": cmake_var_string( - cmake_vars + - llvm_target_cmake_vars("X86", "x86_64-unknown-linux_gnu") + - linux_cmake_vars, - ), -}) - # Performs CMake variable substitutions on configuration header files. expand_cmake_vars( name = "config_gen", src = "include/llvm/Config/config.h.cmake", - cmake_vars = all_cmake_vars, + cmake_vars = llvm_all_cmake_vars, dst = "include/llvm/Config/config.h", ) expand_cmake_vars( name = "llvm_config_gen", src = "include/llvm/Config/llvm-config.h.cmake", - cmake_vars = all_cmake_vars, + cmake_vars = llvm_all_cmake_vars, dst = "include/llvm/Config/llvm-config.h", ) expand_cmake_vars( name = "abi_breaking_gen", src = "include/llvm/Config/abi-breaking.h.cmake", - cmake_vars = all_cmake_vars, + cmake_vars = llvm_all_cmake_vars, dst = "include/llvm/Config/abi-breaking.h", ) @@ -240,14 +121,7 @@ cc_library( "include/llvm/Config/config.h", "include/llvm/Config/llvm-config.h", ], - defines = [ - "LLVM_ENABLE_STATS", - "__STDC_LIMIT_MACROS", - "__STDC_CONSTANT_MACROS", - "__STDC_FORMAT_MACROS", - "_DEBUG", - "LLVM_BUILD_GLOBAL_ISEL", - ], + defines = LLVM_DEFINES, includes = ["include"], ) @@ -262,17 +136,6 @@ genrule( ) # Rules that apply the LLVM tblgen tool. -gentbl( - name = "intrinsics_gen", - tbl_outs = [("-gen-intrinsic", "include/llvm/IR/Intrinsics.inc")], - tblgen = ":llvm-tblgen", - td_file = "include/llvm/IR/Intrinsics.td", - td_srcs = glob([ - "include/llvm/CodeGen/*.td", - "include/llvm/IR/Intrinsics*.td", - ]), -) - gentbl( name = "attributes_gen", tbl_outs = [("-gen-attrs", "include/llvm/IR/Attributes.inc")], @@ -292,6 +155,42 @@ gentbl( ], ) +gentbl( + name = "instcombine_transforms_gen", + tbl_outs = [( + "-gen-searchable-tables", + "lib/Transforms/InstCombine/InstCombineTables.inc", + )], + tblgen = ":llvm-tblgen", + td_file = "lib/Transforms/InstCombine/InstCombineTables.td", + td_srcs = glob([ + "include/llvm/CodeGen/*.td", + "include/llvm/IR/Intrinsics*.td", + ]) + ["include/llvm/TableGen/SearchableTable.td"], +) + +gentbl( + name = "intrinsic_enums_gen", + tbl_outs = [("-gen-intrinsic-enums", "include/llvm/IR/IntrinsicEnums.inc")], + tblgen = ":llvm-tblgen", + td_file = "include/llvm/IR/Intrinsics.td", + td_srcs = glob([ + "include/llvm/CodeGen/*.td", + "include/llvm/IR/Intrinsics*.td", + ]), +) + +gentbl( + name = "intrinsics_impl_gen", + tbl_outs = [("-gen-intrinsic-impl", "include/llvm/IR/IntrinsicImpl.inc")], + tblgen = ":llvm-tblgen", + td_file = "include/llvm/IR/Intrinsics.td", + td_srcs = glob([ + "include/llvm/CodeGen/*.td", + "include/llvm/IR/Intrinsics*.td", + ]), +) + # Binary targets used by Tensorflow. cc_binary( name = "llvm-tblgen", @@ -299,11 +198,7 @@ cc_binary( "utils/TableGen/*.cpp", "utils/TableGen/*.h", ]), - linkopts = [ - "-lm", - "-ldl", - "-lpthread", - ], + linkopts = LLVM_LINKOPTS, stamp = 0, deps = [ ":config", @@ -319,11 +214,7 @@ cc_binary( "utils/FileCheck/*.cpp", "utils/FileCheck/*.h", ]), - linkopts = [ - "-ldl", - "-lm", - "-lpthread", - ], + linkopts = LLVM_LINKOPTS, stamp = 0, deps = [":support"], ) @@ -494,7 +385,8 @@ cc_library( "include/llvm/Target/AArch64/AsmParser/*.inc", "lib/Target/AArch64/AsmParser/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AArch64"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AArch64"], + defines = LLVM_DEFINES, deps = [ ":aarch64_desc", ":aarch64_info", @@ -519,7 +411,8 @@ cc_library( "include/llvm/Target/AArch64/InstPrinter/*.inc", "lib/Target/AArch64/InstPrinter/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AArch64"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AArch64"], + defines = LLVM_DEFINES, deps = [ ":aarch64_target_gen", ":aarch64_utils", @@ -542,7 +435,8 @@ cc_library( "include/llvm/Target/AArch64/*.inc", "lib/Target/AArch64/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AArch64"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AArch64"], + defines = LLVM_DEFINES, deps = [ ":aarch64_asm_printer", ":aarch64_desc", @@ -575,14 +469,16 @@ cc_library( "include/llvm/Target/AArch64/MCTargetDesc/*.inc", "lib/Target/AArch64/MCTargetDesc/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AArch64"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AArch64"], + defines = LLVM_DEFINES, deps = [ ":aarch64_asm_printer", ":aarch64_info", ":aarch64_target_gen", ":attributes_gen", ":config", - ":intrinsics_gen", + ":intrinsic_enums_gen", + ":intrinsics_impl_gen", ":mc", ":support", ], @@ -601,7 +497,8 @@ cc_library( "include/llvm/Target/AArch64/Disassembler/*.inc", "lib/Target/AArch64/Disassembler/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AArch64"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AArch64"], + defines = LLVM_DEFINES, deps = [ ":aarch64_desc", ":aarch64_info", @@ -629,7 +526,8 @@ cc_library( "lib/Target/AArch64/AArch64*.h", "lib/Target/AArch64/TargetInfo/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AArch64"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AArch64"], + defines = LLVM_DEFINES, deps = [ ":code_gen", ":config", @@ -652,7 +550,8 @@ cc_library( "include/llvm/Target/AArch64/Utils/*.inc", "lib/Target/AArch64/Utils/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AArch64"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AArch64"], + defines = LLVM_DEFINES, deps = [ ":aarch64_target_gen", ":config", @@ -674,6 +573,8 @@ cc_library( "include/llvm/Transforms/AggressiveInstCombine/*.def", "include/llvm/Transforms/AggressiveInstCombine/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":config", @@ -698,6 +599,8 @@ cc_library( "include/llvm/Analysis/*.def", "include/llvm/Analysis/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":binary_format", ":config", @@ -721,7 +624,8 @@ cc_library( "include/llvm/Target/AMDGPU/MCTargetDesc/*.inc", "lib/Target/AMDGPU/MCTargetDesc/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AMDGPU"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AMDGPU"], + defines = LLVM_DEFINES, deps = [ ":amdgpu_asm_printer", ":amdgpu_info", @@ -746,7 +650,8 @@ cc_library( "include/llvm/Target/AMDGPU/Disassembler/*.inc", "lib/Target/AMDGPU/Disassembler/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AMDGPU"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AMDGPU"], + defines = LLVM_DEFINES, deps = [ ":amdgpu_desc", ":amdgpu_info", @@ -771,7 +676,8 @@ cc_library( "include/llvm/Target/AMDGPU/TargetInfo/*.inc", "lib/Target/AMDGPU/TargetInfo/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AMDGPU"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AMDGPU"], + defines = LLVM_DEFINES, deps = [ ":amdgpu_target_gen", ":config", @@ -793,7 +699,8 @@ cc_library( "include/llvm/Target/AMDGPU/Utils/*.inc", "lib/Target/AMDGPU/Utils/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AMDGPU"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AMDGPU"], + defines = LLVM_DEFINES, deps = [ ":amdgpu_target_gen", ":config", @@ -816,7 +723,8 @@ cc_library( "include/llvm/Target/AMDGPU/AsmParser/*.inc", "lib/Target/AMDGPU/AsmParser/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AMDGPU"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AMDGPU"], + defines = LLVM_DEFINES, deps = [ ":amdgpu_desc", ":amdgpu_info", @@ -841,7 +749,8 @@ cc_library( "include/llvm/Target/AMDGPU/InstPrinter/*.inc", "lib/Target/AMDGPU/InstPrinter/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AMDGPU"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AMDGPU"], + defines = LLVM_DEFINES, deps = [ ":amdgpu_utils", ":config", @@ -863,7 +772,8 @@ cc_library( "include/llvm/Target/AMDGPU/*.inc", "lib/Target/AMDGPU/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/AMDGPU"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AMDGPU"], + defines = LLVM_DEFINES, deps = [ ":amdgpu_asm_printer", ":amdgpu_desc", @@ -899,7 +809,8 @@ cc_library( "include/llvm/Target/ARM/AsmParser/*.inc", "lib/Target/ARM/AsmParser/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/ARM"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/ARM"], + defines = LLVM_DEFINES, deps = [ ":arm_desc", ":arm_info", @@ -925,7 +836,8 @@ cc_library( "lib/Target/ARM/*.h", "lib/Target/ARM/InstPrinter/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/ARM"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/ARM"], + defines = LLVM_DEFINES, deps = [ ":arm_info", ":arm_target_gen", @@ -949,7 +861,8 @@ cc_library( "include/llvm/Target/ARM/*.inc", "lib/Target/ARM/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/ARM"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/ARM"], + defines = LLVM_DEFINES, deps = [ ":analysis", ":arm_asm_printer", @@ -984,14 +897,16 @@ cc_library( "include/llvm/Target/ARM/MCTargetDesc/*.inc", "lib/Target/ARM/MCTargetDesc/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/ARM"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/ARM"], + defines = LLVM_DEFINES, deps = [ ":arm_asm_printer", ":arm_info", ":arm_target_gen", ":attributes_gen", ":config", - ":intrinsics_gen", + ":intrinsic_enums_gen", + ":intrinsics_impl_gen", ":mc", ":mc_disassembler", ":support", @@ -1011,7 +926,8 @@ cc_library( "include/llvm/Target/ARM/Disassembler/*.inc", "lib/Target/ARM/Disassembler/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/ARM"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/ARM"], + defines = LLVM_DEFINES, deps = [ ":arm_desc", ":arm_info", @@ -1036,7 +952,8 @@ cc_library( "include/llvm/Target/ARM/TargetInfo/*.inc", "lib/Target/ARM/TargetInfo/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/ARM"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/ARM"], + defines = LLVM_DEFINES, deps = [ ":arm_target_gen", ":config", @@ -1059,7 +976,8 @@ cc_library( "include/llvm/Target/ARM/Utils/*.inc", "lib/Target/ARM/Utils/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/ARM"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/ARM"], + defines = LLVM_DEFINES, deps = [ ":arm_target_gen", ":config", @@ -1081,6 +999,8 @@ cc_library( "include/llvm/AsmParser/*.def", "include/llvm/AsmParser/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":binary_format", ":config", @@ -1103,6 +1023,8 @@ cc_library( "include/llvm/CodeGen/AsmPrinter/*.inc", "lib/CodeGen/AsmPrinter/*.def", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":binary_format", @@ -1133,6 +1055,8 @@ cc_library( "include/llvm/BinaryFormat/ELFRelocs/*.def", "include/llvm/BinaryFormat/WasmRelocs/*.def", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":support", @@ -1153,6 +1077,8 @@ cc_library( "include/llvm/Bitcode/Reader/*.inc", "include/llvm/Bitcode/BitstreamReader.h", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":core", @@ -1176,6 +1102,8 @@ cc_library( "include/llvm/Bitcode/BitcodeWriterPass.h", "include/llvm/Bitcode/BitstreamWriter.h", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":config", @@ -1200,6 +1128,8 @@ cc_library( "include/llvm/CodeGen/*.inc", "include/llvm/CodeGen/**/*.h", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":bit_reader", @@ -1237,12 +1167,15 @@ cc_library( "include/llvm/*.h", "include/llvm/Analysis/*.def", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":attributes_compat_gen", ":attributes_gen", ":binary_format", ":config", - ":intrinsics_gen", + ":intrinsic_enums_gen", + ":intrinsics_impl_gen", ":support", ], ) @@ -1260,6 +1193,8 @@ cc_library( "include/llvm/DebugInfo/CodeView/*.def", "include/llvm/DebugInfo/CodeView/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":binary_format", ":config", @@ -1281,6 +1216,8 @@ cc_library( "include/llvm/DebugInfo/MSF/*.def", "include/llvm/DebugInfo/MSF/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":support", @@ -1300,6 +1237,8 @@ cc_library( "include/llvm/Demangle/*.def", "include/llvm/Demangle/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [":config"], ) @@ -1316,6 +1255,8 @@ cc_library( "include/llvm/ExecutionEngine/*.def", "include/llvm/ExecutionEngine/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":core", @@ -1340,6 +1281,8 @@ cc_library( "include/llvm/CodeGen/GlobalISel/*.def", "include/llvm/CodeGen/GlobalISel/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":code_gen", @@ -1369,6 +1312,8 @@ cc_library( "include/llvm/Transforms/InstrProfiling.h", "include/llvm/Transforms/PGOInstrumentation.h", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":config", @@ -1393,10 +1338,13 @@ cc_library( "include/llvm/Transforms/InstCombine/*.def", "include/llvm/Transforms/InstCombine/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":config", ":core", + ":instcombine_transforms_gen", ":support", ":transform_utils", ], @@ -1418,6 +1366,8 @@ cc_library( "include/llvm/Transforms/IPO/*.def", "include/llvm/Transforms/IPO/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":aggressive_inst_combine", ":analysis", @@ -1451,6 +1401,8 @@ cc_library( "include/llvm/IRReader/*.def", "include/llvm/IRReader/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":asm_parser", ":bit_reader", @@ -1473,6 +1425,8 @@ cc_library( "include/llvm/Linker/*.def", "include/llvm/Linker/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":core", @@ -1494,6 +1448,8 @@ cc_library( "include/llvm/MC/*.def", "include/llvm/MC/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":binary_format", ":config", @@ -1515,6 +1471,8 @@ cc_library( "include/llvm/MC/MCDisassembler/*.def", "include/llvm/MC/MCDisassembler/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":mc", @@ -1535,6 +1493,8 @@ cc_library( "include/llvm/MC/MCParser/*.def", "include/llvm/MC/MCParser/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":mc", @@ -1555,7 +1515,8 @@ cc_library( "include/llvm/Target/NVPTX/InstPrinter/*.inc", "lib/Target/NVPTX/InstPrinter/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/NVPTX"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/NVPTX"], + defines = LLVM_DEFINES, deps = [ "nvptx_target_gen", ":attributes_gen", @@ -1579,7 +1540,8 @@ cc_library( "include/llvm/Target/NVPTX/*.inc", "lib/Target/NVPTX/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/NVPTX"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/NVPTX"], + defines = LLVM_DEFINES, deps = [ ":analysis", ":asm_printer", @@ -1613,7 +1575,8 @@ cc_library( "include/llvm/Target/NVPTX/MCTargetDesc/*.inc", "lib/Target/NVPTX/MCTargetDesc/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/NVPTX"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/NVPTX"], + defines = LLVM_DEFINES, deps = [ "nvptx_target_gen", ":config", @@ -1639,7 +1602,8 @@ cc_library( "lib/Target/NVPTX/NVPTX.h", "lib/Target/NVPTX/TargetInfo/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/NVPTX"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/NVPTX"], + defines = LLVM_DEFINES, deps = [ "nvptx_target_gen", ":attributes_gen", @@ -1663,6 +1627,8 @@ cc_library( "include/llvm/Object/*.def", "include/llvm/Object/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":binary_format", ":bit_reader", @@ -1688,6 +1654,8 @@ cc_library( "include/llvm/Transforms/ObjCARC/*.def", "include/llvm/Transforms/ObjCARC/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":config", @@ -1710,13 +1678,17 @@ cc_library( "include/llvm/ExecutionEngine/Orc/*.def", "include/llvm/ExecutionEngine/Orc/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":core", ":execution_engine", + ":mc", ":object", ":runtime_dyld", ":support", + ":target", ":transform_utils", ], ) @@ -1734,7 +1706,8 @@ cc_library( "include/llvm/Target/PowerPC/AsmParser/*.inc", "lib/Target/PowerPC/AsmParser/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/PowerPC"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/PowerPC"], + defines = LLVM_DEFINES, deps = [ ":config", ":mc", @@ -1758,11 +1731,13 @@ cc_library( "include/llvm/Target/PowerPC/InstPrinter/*.inc", "lib/Target/PowerPC/InstPrinter/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/PowerPC"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/PowerPC"], + defines = LLVM_DEFINES, deps = [ ":attributes_gen", ":config", - ":intrinsics_gen", + ":intrinsic_enums_gen", + ":intrinsics_impl_gen", ":mc", ":powerpc_info", ":powerpc_target_gen", @@ -1783,7 +1758,8 @@ cc_library( "include/llvm/Target/PowerPC/*.inc", "lib/Target/PowerPC/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/PowerPC"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/PowerPC"], + defines = LLVM_DEFINES, deps = [ ":analysis", ":asm_printer", @@ -1815,11 +1791,13 @@ cc_library( "include/llvm/Target/PowerPC/MCTargetDesc/*.inc", "lib/Target/PowerPC/MCTargetDesc/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/PowerPC"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/PowerPC"], + defines = LLVM_DEFINES, deps = [ ":attributes_gen", ":config", - ":intrinsics_gen", + ":intrinsic_enums_gen", + ":intrinsics_impl_gen", ":mc", ":powerpc_asm_printer", ":powerpc_info", @@ -1841,7 +1819,8 @@ cc_library( "include/llvm/Target/PowerPC/Disassembler/*.inc", "lib/Target/PowerPC/Disassembler/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/PowerPC"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/PowerPC"], + defines = LLVM_DEFINES, deps = [ ":config", ":mc_disassembler", @@ -1865,12 +1844,12 @@ cc_library( "lib/Target/PowerPC/PPC*.h", "lib/Target/PowerPC/TargetInfo/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/PowerPC"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/PowerPC"], + defines = LLVM_DEFINES, deps = [ ":attributes_gen", ":config", ":core", - ":intrinsics_gen", ":powerpc_target_gen", ":support", ":target", @@ -1890,6 +1869,8 @@ cc_library( "include/llvm/ProfileData/*.def", "include/llvm/ProfileData/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":core", @@ -1918,6 +1899,8 @@ cc_library( "include/llvm/ExecutionEngine/RTDyldMemoryManager.h", "include/llvm/ExecutionEngine/RuntimeDyld*.h", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":mc", @@ -1945,6 +1928,8 @@ cc_library( "include/llvm/Transforms/IPO.h", "include/llvm/Transforms/IPO/SCCP.h", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":aggressive_inst_combine", ":analysis", @@ -1970,6 +1955,8 @@ cc_library( "include/llvm/CodeGen/SelectionDAG/*.def", "include/llvm/CodeGen/SelectionDAG/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":code_gen", @@ -2007,6 +1994,8 @@ cc_library( "include/llvm/BinaryFormat/MachO.def", "include/llvm/Support/VCSRevision.h", ], + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":demangle", @@ -2029,6 +2018,8 @@ cc_library( "include/llvm/TableGen/*.inc", "include/llvm/Target/*.def", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":config", ":mc", @@ -2054,6 +2045,8 @@ cc_library( "include/llvm/CodeGen/*.def", "include/llvm/CodeGen/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":config", @@ -2078,6 +2071,8 @@ cc_library( "include/llvm/Transforms/Utils/*.def", "include/llvm/Transforms/Utils/*.inc", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":config", @@ -2101,6 +2096,8 @@ cc_library( "include/llvm/Transforms/Vectorize/*.inc", "include/llvm/Transforms/Vectorize.h", ]), + copts = LLVM_COPTS, + defines = LLVM_DEFINES, deps = [ ":analysis", ":config", @@ -2124,7 +2121,8 @@ cc_library( "include/llvm/Target/X86/AsmParser/*.inc", "lib/Target/X86/AsmParser/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/X86"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/X86"], + defines = LLVM_DEFINES, deps = [ ":config", ":mc", @@ -2149,7 +2147,8 @@ cc_library( "include/llvm/Target/X86/InstPrinter/*.inc", "lib/Target/X86/InstPrinter/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/X86"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/X86"], + defines = LLVM_DEFINES, deps = [ ":config", ":mc", @@ -2173,7 +2172,8 @@ cc_library( "include/llvm/Target/X86/*.inc", "lib/Target/X86/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/X86"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/X86"], + defines = LLVM_DEFINES, deps = [ ":analysis", ":asm_printer", @@ -2206,7 +2206,8 @@ cc_library( "include/llvm/Target/X86/MCTargetDesc/*.inc", "lib/Target/X86/MCTargetDesc/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/X86"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/X86"], + defines = LLVM_DEFINES, deps = [ ":config", ":mc", @@ -2231,7 +2232,8 @@ cc_library( "include/llvm/Target/X86/Disassembler/*.inc", "lib/Target/X86/Disassembler/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/X86"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/X86"], + defines = LLVM_DEFINES, deps = [ ":config", ":mc_disassembler", @@ -2254,7 +2256,8 @@ cc_library( "include/llvm/Target/X86/TargetInfo/*.inc", "lib/Target/X86/TargetInfo/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/X86"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/X86"], + defines = LLVM_DEFINES, deps = [ ":config", ":mc", @@ -2276,7 +2279,8 @@ cc_library( "include/llvm/Target/X86/Utils/*.inc", "lib/Target/X86/Utils/*.h", ]), - copts = ["-Iexternal/llvm/lib/Target/X86"], + copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/X86"], + defines = LLVM_DEFINES, deps = [ ":code_gen", ":config", diff --git a/third_party/llvm/llvm.bzl b/third_party/llvm/llvm.bzl index 0efcf319bd99be79263a1b9cd23544523a4c8076..2e809e5f147d9e2b359dbf8fcc57575572bc64cd 100644 --- a/third_party/llvm/llvm.bzl +++ b/third_party/llvm/llvm.bzl @@ -105,3 +105,136 @@ def expand_cmake_vars(name, src, dst, cmake_vars): "< $< > $@") ) +# TODO(phawkins): the set of CMake variables was hardcoded for expediency. +# However, we should really detect many of these via configure-time tests. + +# The set of CMake variables common to all targets. +cmake_vars = { + # Headers + "HAVE_DIRENT_H": 1, + "HAVE_DLFCN_H": 1, + "HAVE_ERRNO_H": 1, + "HAVE_EXECINFO_H": 1, + "HAVE_FCNTL_H": 1, + "HAVE_INTTYPES_H": 1, + "HAVE_PTHREAD_H": 1, + "HAVE_SIGNAL_H": 1, + "HAVE_STDINT_H": 1, + "HAVE_SYS_IOCTL_H": 1, + "HAVE_SYS_MMAN_H": 1, + "HAVE_SYS_PARAM_H": 1, + "HAVE_SYS_RESOURCE_H": 1, + "HAVE_SYS_STAT_H": 1, + "HAVE_SYS_TIME_H": 1, + "HAVE_SYS_TYPES_H": 1, + "HAVE_TERMIOS_H": 1, + "HAVE_UNISTD_H": 1, + "HAVE_ZLIB_H": 1, + + # Features + "HAVE_BACKTRACE": 1, + "BACKTRACE_HEADER": "execinfo.h", + "HAVE_DLOPEN": 1, + "HAVE_FUTIMES": 1, + "HAVE_GETCWD": 1, + "HAVE_GETPAGESIZE": 1, + "HAVE_GETRLIMIT": 1, + "HAVE_GETRUSAGE": 1, + "HAVE_GETTIMEOFDAY": 1, + "HAVE_INT64_T": 1, + "HAVE_ISATTY": 1, + "HAVE_LIBEDIT": 1, + "HAVE_LIBPTHREAD": 1, + "HAVE_LIBZ": 1, + "HAVE_MKDTEMP": 1, + "HAVE_MKSTEMP": 1, + "HAVE_MKTEMP": 1, + "HAVE_PREAD": 1, + "HAVE_PTHREAD_GETSPECIFIC": 1, + "HAVE_PTHREAD_MUTEX_LOCK": 1, + "HAVE_PTHREAD_RWLOCK_INIT": 1, + "HAVE_REALPATH": 1, + "HAVE_SBRK": 1, + "HAVE_SETENV": 1, + "HAVE_SETRLIMIT": 1, + "HAVE_SIGALTSTACK": 1, + "HAVE_STRERROR": 1, + "HAVE_STRERROR_R": 1, + "HAVE_STRTOLL": 1, + "HAVE_SYSCONF": 1, + "HAVE_UINT64_T": 1, + "HAVE__UNWIND_BACKTRACE": 1, + + # LLVM features + "ENABLE_BACKTRACES": 1, + "LLVM_BINDIR": "/dev/null", + "LLVM_DISABLE_ABI_BREAKING_CHECKS_ENFORCING": 0, + "LLVM_ENABLE_ABI_BREAKING_CHECKS": 0, + "LLVM_ENABLE_THREADS": 1, + "LLVM_ENABLE_ZLIB": 1, + "LLVM_HAS_ATOMICS": 1, + "LLVM_INCLUDEDIR": "/dev/null", + "LLVM_INFODIR": "/dev/null", + "LLVM_MANDIR": "/dev/null", + "LLVM_NATIVE_TARGET": 1, + "LLVM_NATIVE_TARGETINFO": 1, + "LLVM_NATIVE_TARGETMC": 1, + "LLVM_NATIVE_ASMPRINTER": 1, + "LLVM_NATIVE_ASMPARSER": 1, + "LLVM_NATIVE_DISASSEMBLER": 1, + "LLVM_ON_UNIX": 1, + "LLVM_PREFIX": "/dev/null", + "LLVM_VERSION_MAJOR": 0, + "LLVM_VERSION_MINOR": 0, + "LLVM_VERSION_PATCH": 0, + "LTDL_SHLIB_EXT": ".so", + "PACKAGE_NAME": "llvm", + "PACKAGE_STRING": "llvm tensorflow-trunk", + "PACKAGE_VERSION": "tensorflow-trunk", + "RETSIGTYPE": "void", +} + +# CMake variables specific to the Linux platform +linux_cmake_vars = { + "HAVE_MALLOC_H": 1, + "HAVE_LINK_H": 1, + "HAVE_MALLINFO": 1, + "HAVE_FUTIMENS": 1, +} + +# CMake variables specific to the Darwin (Mac OS X) platform. +darwin_cmake_vars = { + "HAVE_MALLOC_MALLOC_H": 1, +} + +# Select a set of CMake variables based on the platform. +# TODO(phawkins): use a better method to select the right host triple, rather +# than hardcoding x86_64. +llvm_all_cmake_vars = select({ + "@org_tensorflow//tensorflow:darwin": cmake_var_string( + cmake_vars + llvm_target_cmake_vars("X86", "x86_64-apple-darwin") + + darwin_cmake_vars), + "@org_tensorflow//tensorflow:linux_ppc64le": cmake_var_string( + cmake_vars + + llvm_target_cmake_vars("PowerPC", "powerpc64le-unknown-linux_gnu") + + linux_cmake_vars, + ), + "//conditions:default": cmake_var_string( + cmake_vars + + llvm_target_cmake_vars("X86", "x86_64-unknown-linux_gnu") + + linux_cmake_vars), + +}) + +LLVM_LINKOPTS = ["-ldl", "-lm", "-lpthread"] + +LLVM_DEFINES = [ + "LLVM_ENABLE_STATS", + "__STDC_LIMIT_MACROS", + "__STDC_CONSTANT_MACROS", + "__STDC_FORMAT_MACROS", + "_DEBUG", + "LLVM_BUILD_GLOBAL_ISEL", +] + +LLVM_COPTS = [] diff --git a/third_party/repo.bzl b/third_party/repo.bzl index cb67d3e9617dd1e9374d07cb1536cedf4bc74ae8..9cee1fcc4b5c2b05ecc09b4f372eadeca9e91be8 100644 --- a/third_party/repo.bzl +++ b/third_party/repo.bzl @@ -16,7 +16,6 @@ _SINGLE_URL_WHITELIST = depset([ "arm_compiler", - "ortools_archive", ]) def _is_windows(ctx): diff --git a/third_party/sqlite.BUILD b/third_party/sqlite.BUILD index 6da795358927f5cb8db7cb0d7ea653b80f8b5226..2876f305f1f74e8bba9a364b1ef582f42c72c313 100644 --- a/third_party/sqlite.BUILD +++ b/third_party/sqlite.BUILD @@ -5,6 +5,7 @@ licenses(["unencumbered"]) # Public Domain SQLITE_COPTS = [ "-Os", + "-DSQLITE_ENABLE_JSON1", "-DHAVE_DECL_STRERROR_R=1", "-DHAVE_STDINT_H=1", "-DHAVE_INTTYPES_H=1", diff --git a/third_party/toolchains/BUILD b/third_party/toolchains/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..fc3183a754369fc30dbce40c2bf7b6828ea497c3 --- /dev/null +++ b/third_party/toolchains/BUILD @@ -0,0 +1,22 @@ +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +# Platform for use with remote execution with +# custom container based off RBE Ubuntu16_04 +# http://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04 +# Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cpu +platform( + name = "rbe_ubuntu16_04-tf", + constraint_values = [ + "@bazel_tools//platforms:x86_64", + "@bazel_tools//platforms:linux", + "@bazel_tools//tools/cpp:clang", + "@bazel_toolchains//constraints:xenial", + ], + remote_execution_properties = """ + properties: { + name: "container-image" + value:"docker://gcr.io/asci-toolchain/nosla-ubuntu16_04-tf@sha256:800a7b68cabef15419695c188ed33ed70adf678c2371b97b236f3ae26c38274d" + }""", +) diff --git a/third_party/toolchains/clang6/CROSSTOOL.tpl b/third_party/toolchains/clang6/CROSSTOOL.tpl index 6b7e5a88086f8e5e67fa86a0e9377c3c2afd535d..ffba9850bb80a880d5b95afacbad296ec1f2df54 100644 --- a/third_party/toolchains/clang6/CROSSTOOL.tpl +++ b/third_party/toolchains/clang6/CROSSTOOL.tpl @@ -76,9 +76,6 @@ toolchain { # This adds a little bit more durability to our Clang build. # - # At the moment, this only only be needed for: - # - add_boringssl_s390x.patch: --Wa,--noexecstack - # # Folks who do maintenance work on TF Bazel Clang should consider # commenting out these lines, while doing that work, to gain a better # understanding of what the intersection of support looks like between GCC